diff --git a/.github/workflows/fast_tests.yml b/.github/workflows/fast_tests.yml
index 990de2ccfb..b68d4b2dbb 100644
--- a/.github/workflows/fast_tests.yml
+++ b/.github/workflows/fast_tests.yml
@@ -15,8 +15,7 @@ concurrency:
jobs:
transformers:
name: Run tests for optimum.habana.transformers
- runs-on:
- group: aws-dl1-24xlarge
+ runs-on: [self-hosted, linux, x64, gaudi2]
steps:
- name: Checkout
uses: actions/checkout@v2
@@ -38,8 +37,7 @@ jobs:
/bin/bash tests/ci/fast_tests.sh
diffusers:
name: Run tests for optimum.habana.diffusers
- runs-on:
- group: aws-dl1-24xlarge
+ runs-on: [self-hosted, linux, x64, gaudi2]
steps:
- name: Checkout
uses: actions/checkout@v2
diff --git a/Makefile b/Makefile
index 636ce76a04..e6989aa1b0 100644
--- a/Makefile
+++ b/Makefile
@@ -35,6 +35,8 @@ style: clean
fast_tests:
python -m pip install .[tests]
python -m pytest tests/test_gaudi_configuration.py tests/test_trainer_distributed.py tests/test_trainer.py tests/test_trainer_seq2seq.py
+# TODO enable when CI has more servers
+# python -m pytest test_functional_text_generation_example.py
# Run unit and integration tests related to Diffusers
fast_tests_diffusers:
diff --git a/README.md b/README.md
index 2219cf5e3d..11811eb065 100644
--- a/README.md
+++ b/README.md
@@ -214,6 +214,8 @@ The following model architectures, tasks and device distributions have been vali
| Qwen2 |
Single card | Single card | [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| Qwen2-MoE | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| Gemma | :heavy_check_mark: | Single card | [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
+| XGLM | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
+| Cohere | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| T5 / Flan T5 | :heavy_check_mark: | :heavy_check_mark: | [summarization](https://github.com/huggingface/optimum-habana/tree/main/examples/summarization)[translation](https://github.com/huggingface/optimum-habana/tree/main/examples/translation)[question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering#fine-tuning-t5-on-squad20) |
| BART | | Single card | [summarization](https://github.com/huggingface/optimum-habana/tree/main/examples/summarization)[translation](https://github.com/huggingface/optimum-habana/tree/main/examples/translation)[question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering#fine-tuning-t5-on-squad20) |
| ViT | :heavy_check_mark: | :heavy_check_mark: | [image classification](https://github.com/huggingface/optimum-habana/tree/main/examples/image-classification) |
@@ -228,10 +230,12 @@ The following model architectures, tasks and device distributions have been vali
| OWLViT | | Single card | [zero shot object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/zero-shot-object-detection) |
| ClipSeg | | Single card | [object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation) |
| Llava / Llava-next | | Single card | [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text) |
+| idefics2 | LoRA | Single card | [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text) |
| Segment Anything Model | | Single card | [object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation) |
| VideoMAE | | Single card | [Video classification](https://github.com/huggingface/optimum-habana/tree/main/examples/video-classification) |
| TableTransformer | | Single card | [table object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/table-detection) |
| DETR | | Single card | [object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/object-detection) |
+| Mllama | LoRA | :heavy_check_mark: | [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text) |
diff --git a/docs/source/index.mdx b/docs/source/index.mdx
index 947e41b6fe..2b0606364e 100644
--- a/docs/source/index.mdx
+++ b/docs/source/index.mdx
@@ -60,6 +60,8 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
| Qwen2 | Single card | Single card | [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| Qwen2-MoE | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| Persimmon | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
+| XGLM | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
+| Cohere | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| T5 / Flan T5 | ✅ | ✅ | [summarization](https://github.com/huggingface/optimum-habana/tree/main/examples/summarization)[translation](https://github.com/huggingface/optimum-habana/tree/main/examples/translation)[question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering#fine-tuning-t5-on-squad20) |
| BART | | Single card | [summarization](https://github.com/huggingface/optimum-habana/tree/main/examples/summarization)[translation](https://github.com/huggingface/optimum-habana/tree/main/examples/translation)[question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering#fine-tuning-t5-on-squad20) |
| ViT | ✅ | ✅ | [image classification](https://github.com/huggingface/optimum-habana/tree/main/examples/image-classification) |
@@ -74,10 +76,12 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
| OWLViT | | Single card | [zero shot object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/zero-shot-object-detection) |
| ClipSeg | | Single card | [object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation) |
| Llava / Llava-next | | Single card | [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text) |
+| idefics2 | LoRA | Single card | [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text) |
| SAM | | Single card | [object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation) |
| VideoMAE | | Single card | [Video classification](https://github.com/huggingface/optimum-habana/tree/main/examples/video-classification) |
| TableTransformer | | Single card | [table object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/table-detection) |
| DETR | | Single card | [object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/object-detection) |
+| Mllama | LoRA |✅ | [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text) |
- Diffusers
diff --git a/docs/source/package_reference/gaudi_config.mdx b/docs/source/package_reference/gaudi_config.mdx
index 1060e9c64e..a7b9f077b5 100644
--- a/docs/source/package_reference/gaudi_config.mdx
+++ b/docs/source/package_reference/gaudi_config.mdx
@@ -20,8 +20,8 @@ Here is a description of each configuration parameter:
- `use_fused_adam` enables to decide whether to use the [custom fused implementation of the ADAM optimizer provided by Intel® Gaudi® AI Accelerator](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Custom_Ops_PyTorch.html#custom-optimizers).
- `use_fused_clip_norm` enables to decide whether to use the [custom fused implementation of gradient norm clipping provided by Intel® Gaudi® AI Accelerator](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Custom_Ops_PyTorch.html#other-custom-ops).
- `use_torch_autocast` enables PyTorch autocast; used to define good pre-defined config; users should favor `--bf16` training argument
-- `autocast_bf16_ops` list of operations that should be run with bf16 precision under autocast context; using environment flag LOWER_LIST is a preffered way for operator autocast list override
-- `autocast_fp32_ops` list of operations that should be run with fp32 precision under autocast context; using environment flag FP32_LIST is a preffered way for operator autocast list override
+- `autocast_bf16_ops` list of operations that should be run with bf16 precision under autocast context; using environment flag PT_HPU_AUTOCAST_LOWER_PRECISION_OPS_LIST is a preffered way for operator autocast list override
+- `autocast_fp32_ops` list of operations that should be run with fp32 precision under autocast context; using environment flag PT_HPU_AUTOCAST_FP32_OPS_LIST is a preffered way for operator autocast list override
You can find examples of Gaudi configurations in the [Habana model repository on the Hugging Face Hub](https://huggingface.co/habana). For instance, [for BERT Large we have](https://huggingface.co/Habana/bert-large-uncased-whole-word-masking/blob/main/gaudi_config.json):
diff --git a/examples/image-classification/README.md b/examples/image-classification/README.md
index 4118195015..08c4d67123 100644
--- a/examples/image-classification/README.md
+++ b/examples/image-classification/README.md
@@ -33,7 +33,7 @@ pip install -r requirements.txt
Here we show how to fine-tune a Vision Transformer (`ViT`) on Cifar10:
```bash
-python run_image_classification.py \
+PT_HPU_LAZY_MODE=0 python run_image_classification.py \
--model_name_or_path google/vit-base-patch16-224-in21k \
--dataset_name cifar10 \
--output_dir /tmp/outputs/ \
@@ -51,10 +51,11 @@ python run_image_classification.py \
--save_total_limit 3 \
--seed 1337 \
--use_habana \
- --use_lazy_mode \
- --use_hpu_graphs_for_inference \
+ --use_lazy_mode False \
+ --torch_compile_backend hpu_backend \
+ --torch_compile \
--gaudi_config_name Habana/vit \
- --throughput_warmup_steps 3 \
+ --throughput_warmup_steps 6 \
--dataloader_num_workers 1 \
--bf16
```
@@ -92,7 +93,7 @@ root/cat/[...]/asd932_.png
In other words, you need to organize your images in subfolders, based on their class. You can then run the script like this:
```bash
-python run_image_classification.py \
+PT_HPU_LAZY_MODE=0 python run_image_classification.py \
--model_name_or_path google/vit-base-patch16-224-in21k \
--train_dir \
--output_dir /tmp/outputs/ \
@@ -100,8 +101,9 @@ python run_image_classification.py \
--do_train \
--do_eval \
--use_habana \
- --use_lazy_mode \
- --use_hpu_graphs_for_inference \
+ --use_lazy_mode False \
+ --torch_compile_backend hpu_backend \
+ --torch_compile \
--gaudi_config_name Habana/vit \
--throughput_warmup_steps 3 \
--dataloader_num_workers 1 \
@@ -184,7 +186,7 @@ python run_image_classification.py \
Here is how you would fine-tune ViT on Cifar10 using 8 HPUs:
```bash
-python ../gaudi_spawn.py \
+PT_HPU_LAZY_MODE=0 python ../gaudi_spawn.py \
--world_size 8 --use_mpi run_image_classification.py \
--model_name_or_path google/vit-base-patch16-224-in21k \
--dataset_name cifar10 \
@@ -203,8 +205,9 @@ python ../gaudi_spawn.py \
--save_total_limit 3 \
--seed 1337 \
--use_habana \
- --use_lazy_mode \
- --use_hpu_graphs_for_inference \
+ --use_lazy_mode False \
+ --torch_compile_backend hpu_backend \
+ --torch_compile \
--gaudi_config_name Habana/vit \
--throughput_warmup_steps 8 \
--dataloader_num_workers 1 \
@@ -224,7 +227,7 @@ For Swin, you need to change/add the following arguments:
Similarly to multi-HPU training, here is how you would fine-tune ViT on Cifar10 using 8 HPUs with DeepSpeed:
```bash
-python ../gaudi_spawn.py \
+PT_HPU_LAZY_MODE=0 python ../gaudi_spawn.py \
--world_size 8 --use_deepspeed run_image_classification.py \
--model_name_or_path google/vit-base-patch16-224-in21k \
--dataset_name cifar10 \
@@ -243,8 +246,9 @@ python ../gaudi_spawn.py \
--save_total_limit 3 \
--seed 1337 \
--use_habana \
- --use_lazy_mode \
- --use_hpu_graphs_for_inference \
+ --use_lazy_mode False \
+ --torch_compile_backend hpu_backend \
+ --torch_compile \
--gaudi_config_name Habana/vit \
--throughput_warmup_steps 3 \
--dataloader_num_workers 1 \
diff --git a/examples/image-to-text/README.md b/examples/image-to-text/README.md
index b5e261f32a..5916de4a29 100644
--- a/examples/image-to-text/README.md
+++ b/examples/image-to-text/README.md
@@ -30,6 +30,8 @@ Models that have been validated:
- [llava-hf/llava-v1.6-vicuna-13b-hf](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf)
- [llava-hf/llava-v1.6-34b-hf](https://huggingface.co/llava-hf/llava-v1.6-34b-hf)
- [llava-hf/llama3-llava-next-8b-hf](https://huggingface.co/llava-hf/llama3-llava-next-8b-hf)
+ - [HuggingFaceM4/idefics2-8b](https://huggingface.co/HuggingFaceM4/idefics2-8b)
+ - [meta-llama/Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
### Inference with BF16
@@ -92,6 +94,24 @@ python3 run_pipeline.py \
--bf16
```
+To run idefics2 inference, use the following command:
+
+```bash
+python3 run_pipeline.py \
+ --model_name_or_path HuggingFaceM4/idefics2-8b \
+ --use_hpu_graphs \
+ --bf16
+```
+
+To run mllama inference, use the following command:
+
+```bash
+python3 run_pipeline.py \
+ --model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct \
+ --use_hpu_graphs \
+ --bf16
+```
+
### Inference with FP8
Inference for Llava-1.5-7b, Llava-1.5-13b, Llava-v1.6-mistral-7b and Llava-v1.6-vicuna-13b in FP8 precision are enabled using [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html), which provides model measurement and quantization capabilities in PyTorch.
@@ -101,56 +121,56 @@ https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP
Here is an example to measure the tensor quantization statistics on Llava-1.5-7b:
```bash
QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_pipeline.py \
---model_name_or_path llava-hf/llava-1.5-7b-hf \
---image_path "https://llava-vl.github.io/static/images/view.jpg" \
---use_hpu_graphs \
---bf16
+ --model_name_or_path llava-hf/llava-1.5-7b-hf \
+ --image_path "https://llava-vl.github.io/static/images/view.jpg" \
+ --use_hpu_graphs \
+ --bf16
```
Here is an example to quantize the model based on previous measurements for Llava-1.5-7b:
```bash
QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \
---model_name_or_path llava-hf/llava-1.5-7b-hf \
---image_path "https://llava-vl.github.io/static/images/view.jpg" \
---use_hpu_graphs \
---bf16
+ --model_name_or_path llava-hf/llava-1.5-7b-hf \
+ --image_path "https://llava-vl.github.io/static/images/view.jpg" \
+ --use_hpu_graphs \
+ --bf16
```
Here is an example to measure the tensor quantization statistics on Llava-v1.6-mistral-7b:
```bash
QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_pipeline.py \
---model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
---image_path "https://llava-vl.github.io/static/images/view.jpg" \
---use_hpu_graphs \
---bf16
+ --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
+ --image_path "https://llava-vl.github.io/static/images/view.jpg" \
+ --use_hpu_graphs \
+ --bf16
```
Here is an example to quantize the model based on previous measurements for Llava-v1.6-mistral-7b:
```bash
QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \
---model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
---image_path "https://llava-vl.github.io/static/images/view.jpg" \
---use_hpu_graphs \
---bf16
+ --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
+ --image_path "https://llava-vl.github.io/static/images/view.jpg" \
+ --use_hpu_graphs \
+ --bf16
```
Here is an example to measure the tensor quantization statistics on Llava-v1.6-vicuna-13b:
```bash
QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_pipeline.py \
---model_name_or_path llava-hf/llava-v1.6-vicuna-13b-hf \
---image_path "https://llava-vl.github.io/static/images/view.jpg" \
---use_hpu_graphs \
---bf16
+ --model_name_or_path llava-hf/llava-v1.6-vicuna-13b-hf \
+ --image_path "https://llava-vl.github.io/static/images/view.jpg" \
+ --use_hpu_graphs \
+ --bf16
```
Here is an example to quantize the model based on previous measurements for Llava-v1.6-vicuna-13b:
```bash
QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \
---model_name_or_path llava-hf/llava-v1.6-vicuna-13b-hf \
---image_path "https://llava-vl.github.io/static/images/view.jpg" \
---use_hpu_graphs \
---bf16
+ --model_name_or_path llava-hf/llava-v1.6-vicuna-13b-hf \
+ --image_path "https://llava-vl.github.io/static/images/view.jpg" \
+ --use_hpu_graphs \
+ --bf16
```
### Inference with FusedSDPA
@@ -186,6 +206,174 @@ Use the following commands to run Llava-v1.6-mistral-7b FP8 inference with Fused
Here is an example of measuring the tensor quantization statistics on Llava-v1.6-mistral-7b:
```bash
QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_pipeline.py \
+ --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
+ --image_path "https://llava-vl.github.io/static/images/view.jpg" \
+ --use_hpu_graphs \
+ --bf16 \
+ --use_flash_attention \
+ --flash_attention_recompute
+```
+
+Here is an example of quantizing the model based on previous measurements for Llava-v1.6-mistral-7b:
+```bash
+QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \
+ --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
+ --image_path "https://llava-vl.github.io/static/images/view.jpg" \
+ --use_hpu_graphs \
+ --bf16 \
+ --use_flash_attention \
+ --flash_attention_recompute
+```
+## LORA Finetune
+
+To run LoRA finetuning, you can use `run_image2text_lora_finetune.py`.
+Here are single-/multi-device command examples for HuggingFaceM4/idefics2-8b.
+
+```bash
+python3 run_image2text_lora_finetune.py \
+ --model_name_or_path HuggingFaceM4/idefics2-8b \
+ --dataset_name nielsr/docvqa_1200_examples \
+ --bf16 True \
+ --output_dir ./model_lora_llama \
+ --num_train_epochs 1 \
+ --per_device_train_batch_size 2 \
+ --per_device_eval_batch_size 2 \
+ --gradient_accumulation_steps 8 \
+ --weight_decay 0.01 \
+ --logging_steps 25 \
+ --eval_strategy "no" \
+ --save_strategy "no" \
+ --learning_rate 5e-5 \
+ --warmup_steps 50 \
+ --lr_scheduler_type "constant" \
+ --input_column_names 'image' 'query' \
+ --output_column_names 'answers' \
+ --remove_unused_columns False \
+ --do_train \
+ --do_eval \
+ --use_habana \
+ --use_lazy_mode \
+ --lora_rank=8 \
+ --lora_alpha=8 \
+ --lora_dropout=0.1 \
+ --max_seq_length=512 \
+ --use_hpu_graphs_for_inference \
+ --low_cpu_mem_usage True \
+ --lora_target_modules '.*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$'
+```
+
+```bash
+python3 ../gaudi_spawn.py \
+ --world_size 8 --use_mpi run_image2text_lora_finetune.py \
+ --model_name_or_path HuggingFaceM4/idefics2-8b \
+ --dataset_name nielsr/docvqa_1200_examples \
+ --bf16 True \
+ --output_dir ./model_lora_llama \
+ --num_train_epochs 1 \
+ --per_device_train_batch_size 2 \
+ --per_device_eval_batch_size 2 \
+ --gradient_accumulation_steps 8 \
+ --weight_decay 0.01 \
+ --logging_steps 25 \
+ --eval_strategy "no" \
+ --save_strategy "no" \
+ --learning_rate 5e-5 \
+ --warmup_steps 50 \
+ --lr_scheduler_type "constant" \
+ --input_column_names 'image' 'query' \
+ --output_column_names 'answers' \
+ --remove_unused_columns False \
+ --do_train \
+ --do_eval \
+ --use_habana \
+ --use_lazy_mode \
+ --lora_rank=8 \
+ --lora_alpha=8 \
+ --lora_dropout=0.1 \
+ --max_seq_length=512 \
+ --use_hpu_graphs_for_inference \
+ --low_cpu_mem_usage True \
+ --lora_target_modules '".*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"'
+```
+
+Here are single-/multi-device command examples for meta-llama/Llama-3.2-11B-Vision-Instruct.
+
+```bash
+python3 run_image2text_lora_finetune.py \
+ --model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct \
+ --dataset_name nielsr/docvqa_1200_examples \
+ --bf16 True \
+ --output_dir ./model_lora_llama \
+ --num_train_epochs 2 \
+ --per_device_train_batch_size 2 \
+ --per_device_eval_batch_size 2 \
+ --gradient_accumulation_steps 8 \
+ --weight_decay 0.01 \
+ --logging_steps 25 \
+ --eval_strategy "no" \
+ --save_strategy "no" \
+ --learning_rate 5e-5 \
+ --warmup_steps 50 \
+ --lr_scheduler_type "constant" \
+ --input_column_names 'image' 'query' \
+ --output_column_names 'answers' \
+ --remove_unused_columns False \
+ --do_train \
+ --do_eval \
+ --use_habana \
+ --use_lazy_mode \
+ --lora_rank=8 \
+ --lora_alpha=8 \
+ --lora_dropout=0.1 \
+ --low_cpu_mem_usage True \
+ --max_seq_length=512 \
+ --use_hpu_graphs_for_inference True \
+ --lora_target_modules ".*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"
+```
+
+```bash
+python3 ../gaudi_spawn.py \
+ --world_size 8 --use_mpi run_image2text_lora_finetune.py \
+ --model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct \
+ --dataset_name nielsr/docvqa_1200_examples \
+ --bf16 True \
+ --output_dir ./model_lora_llama \
+ --num_train_epochs 2 \
+ --per_device_train_batch_size 2 \
+ --per_device_eval_batch_size 2 \
+ --gradient_accumulation_steps 8 \
+ --weight_decay 0.01 \
+ --logging_steps 25 \
+ --eval_strategy "no" \
+ --save_strategy "no" \
+ --learning_rate 5e-5 \
+ --warmup_steps 50 \
+ --lr_scheduler_type "constant" \
+ --input_column_names 'image' 'query' \
+ --output_column_names 'answers' \
+ --remove_unused_columns False \
+ --do_train \
+ --do_eval \
+ --use_habana \
+ --use_lazy_mode \
+ --lora_rank=8 \
+ --lora_alpha=8 \
+ --lora_dropout=0.1 \
+ --low_cpu_mem_usage True \
+ --max_seq_length=512 \
+ --use_hpu_graphs_for_inference True \
+ --lora_target_modules '".*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"'
+```
+
+## Multi-HPU inference
+
+To enable multi-card inference, you must set the environment variable `PT_HPU_ENABLE_LAZY_COLLECTIVES=true`,
+
+### BF16 Inference with FusedSDPA on 8 HPUs
+
+Use the following commands to run Llava-v1.6-mistral-7b BF16 inference with FusedSDPA on 8 HPUs:
+```bash
+PT_HPU_ENABLE_LAZY_COLLECTIVES=true python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py \
--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
--image_path "https://llava-vl.github.io/static/images/view.jpg" \
--use_hpu_graphs \
@@ -194,9 +382,12 @@ QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_pipeline.py \
--flash_attention_recompute
```
-Here is an example of quantizing the model based on previous measurements for Llava-v1.6-mistral-7b:
+### FP8 Inference with FusedSDPA on 8 HPUs
+
+Use the following commands to run Llava-v1.6-mistral-7b FP8 inference with FusedSDPA on 8 HPUs.
+Here is an example of measuring the tensor quantization statistics on Llava-v1.6-mistral-7b on 8 HPUs:
```bash
-QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \
+QUANT_CONFIG=./quantization_config/maxabs_measure.json PT_HPU_ENABLE_LAZY_COLLECTIVES=true python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py \
--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
--image_path "https://llava-vl.github.io/static/images/view.jpg" \
--use_hpu_graphs \
@@ -204,3 +395,14 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \
--use_flash_attention \
--flash_attention_recompute
```
+
+Here is an example of quantizing the model based on previous measurements for Llava-v1.6-mistral-7b on 8 HPUs:
+```bash
+QUANT_CONFIG=./quantization_config/maxabs_quant.json PT_HPU_ENABLE_LAZY_COLLECTIVES=true python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py \
+--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
+--image_path "https://llava-vl.github.io/static/images/view.jpg" \
+--use_hpu_graphs \
+--bf16 \
+--use_flash_attention \
+--flash_attention_recompute
+```
\ No newline at end of file
diff --git a/examples/image-to-text/requirements.txt b/examples/image-to-text/requirements.txt
new file mode 100644
index 0000000000..045bd69e64
--- /dev/null
+++ b/examples/image-to-text/requirements.txt
@@ -0,0 +1,2 @@
+peft == 0.12.0
+Levenshtein
diff --git a/examples/image-to-text/run_image2text_lora_finetune.py b/examples/image-to-text/run_image2text_lora_finetune.py
new file mode 100644
index 0000000000..ded60e6d52
--- /dev/null
+++ b/examples/image-to-text/run_image2text_lora_finetune.py
@@ -0,0 +1,553 @@
+# Apache v2 license
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+lora fine tuning script for image-to-text case
+Adapted from the following sources:
+https://colab.research.google.com/drive/1rm3AGquGEYXfeeizE40bbDtcWh5S4Nlq?usp=sharing
+"""
+
+import logging
+import os
+import random
+import sys
+from dataclasses import dataclass, field
+from typing import List, Optional
+
+import Levenshtein
+import torch
+import transformers
+from datasets import load_dataset
+from peft import LoraConfig, get_peft_model
+from transformers import (
+ AutoConfig,
+ AutoModelForVision2Seq,
+ AutoProcessor,
+ HfArgumentParser,
+)
+from transformers.trainer_utils import is_main_process
+
+from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments
+
+
+try:
+ from optimum.habana.utils import check_optimum_habana_min_version
+except ImportError:
+
+ def check_optimum_habana_min_version(*a, **b):
+ return ()
+
+
+os.environ["WANDB_DISABLED"] = "true"
+
+logger = logging.getLogger(__name__)
+
+# Will error if the minimal version of Optimum Habana is not installed. Remove at your own risks.
+check_optimum_habana_min_version("1.10.0")
+
+
+def normalized_levenshtein(s1, s2):
+ len_s1, len_s2 = len(s1), len(s2)
+ distance = Levenshtein.distance(s1, s2)
+ return distance / max(len_s1, len_s2)
+
+
+def similarity_score(a_ij, o_q_i, tau=0.5):
+ nl = normalized_levenshtein(a_ij, o_q_i)
+ return 1 - nl if nl < tau else 0
+
+
+def average_normalized_levenshtein_similarity(ground_truth, predicted_answers):
+ assert len(ground_truth) == len(predicted_answers), "Length of ground_truth and predicted_answers must match."
+
+ N = len(ground_truth)
+ total_score = 0
+
+ for i in range(N):
+ a_i = ground_truth[i]
+ o_q_i = predicted_answers[i]
+ if o_q_i == "":
+ print("Warning: Skipped an empty prediction.")
+ max_score = 0
+ else:
+ max_score = max(similarity_score(a_ij, o_q_i) for a_ij in a_i)
+
+ total_score += max_score
+
+ return total_score / N
+
+
+@dataclass
+class ModelArguments:
+ """
+ Arguments pertaining to which model/config/processor we are going to fine-tune, or train from scratch.
+ """
+
+ model_name_or_path: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "The model checkpoint for weights initialization."
+ "Don't set if you want to train a model from scratch."
+ },
+ )
+ config_name: Optional[str] = field(
+ default=None,
+ metadata={"help": "Pretrained config name or path if not the same as model_name"},
+ )
+ cache_dir: Optional[str] = field(
+ default=None,
+ metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
+ )
+ token: Optional[str] = field(
+ default=None,
+ metadata={"help": "auth token for private models"},
+ )
+ use_fast_tokenizer: bool = field(
+ default=True,
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
+ )
+ model_revision: str = field(
+ default="main",
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+ )
+ trust_remote_code: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Whether to trust the execution of code from datasets/models defined on the Hub."
+ " This option should only be set to `True` for repositories you trust and in which you have read the"
+ " code, as it will execute code present on the Hub on your local machine."
+ )
+ },
+ )
+ use_cache: bool = field(
+ default=True,
+ metadata={
+ "help": (
+ "Whether or not the model should return the last key/values attentions (not used by all models)."
+ "Only relevant if `config.is_decoder=True`."
+ )
+ },
+ )
+ low_cpu_mem_usage: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded."
+ "When set to True, it will benefit LLM loading time and RAM consumption."
+ )
+ },
+ )
+ load_meta_device: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "It is an option to load the model to the device instead of the host, so it can reduce the host RAM usage."
+ "https://huggingface.co/blog/accelerate-large-models"
+ )
+ },
+ )
+ do_image_splitting: bool = field(default=False, metadata={"help": "Whether to do image split during finetune."})
+
+
+@dataclass
+class DataArguments:
+ """
+ Arguments pertaining to what data we are going to input our model for training and eval.
+ """
+
+ dataset_name: Optional[str] = field(
+ default=None,
+ metadata={"help": "The name of the dataset to use (via the datasets library)."},
+ )
+ dataset_config_name: Optional[str] = field(
+ default=None,
+ metadata={"help": "The configuration name of the dataset to use (via the datasets library)."},
+ )
+ max_seq_length: Optional[int] = field(
+ default=512,
+ metadata={
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated."
+ },
+ )
+ overwrite_cache: bool = field(
+ default=False,
+ metadata={"help": "Overwrite the cached preprocessed datasets or not."},
+ )
+ max_train_samples: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ },
+ )
+ max_eval_samples: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ },
+ )
+ dataset_seed: int = field(
+ default=42,
+ metadata={
+ "help": "Seed to use in dataset processing, different seeds might yield different datasets. This seed and the seed in training arguments are not related"
+ },
+ )
+ save_last_ckpt: bool = field(
+ default=True, metadata={"help": "Whether to save checkpoint at the end of the training."}
+ )
+ input_column_names: List[str] = field(
+ default_factory=lambda: None,
+ metadata={
+ "help": "Name of the column in the dataset that optionally provides context or input for the task. By "
+ "default, 'image,query' columns are used"
+ },
+ )
+ output_column_names: List[str] = field(
+ default_factory=lambda: None,
+ metadata={
+ "help": "Name of the column in the dataset with the answer to the instruction. By default, the "
+ "'answers' column is used"
+ },
+ )
+
+
+@dataclass
+class FinetuneArguments:
+ """
+ Arguments of finetune we are going to apply on the model.
+ """
+
+ lora_rank: int = field(
+ default=8,
+ metadata={"help": "Rank parameter in the LoRA method."},
+ )
+ lora_alpha: int = field(
+ default=8,
+ metadata={"help": "Alpha parameter in the LoRA method."},
+ )
+ lora_dropout: float = field(
+ default=0.1,
+ metadata={"help": "Dropout parameter in the LoRA method."},
+ )
+ lora_target_modules: str = field(
+ default=None,
+ metadata={"help": "Target modules for the LoRA/AdaLoRA method."},
+ )
+
+
+class MyDataCollator:
+ def __init__(self, processor, max_seq_length, image_token_id):
+ self.processor = processor
+ self.image_token_id = image_token_id
+ self.max_seq_length = max_seq_length
+
+ def __call__(self, examples):
+ texts = []
+ images = []
+ keys = list(examples[0].keys())
+ if not all(key in ["image", "query", "answers"] for key in keys):
+ raise ValueError("Unsupported dataset format")
+ for example in examples:
+ image = example["image"]
+ question = example["query"]["en"]
+ answer = random.choice(example["answers"])
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Answer briefly."},
+ {"type": "image"},
+ {"type": "text", "text": question},
+ ],
+ },
+ {"role": "assistant", "content": [{"type": "text", "text": answer}]},
+ ]
+ text = self.processor.apply_chat_template(messages, add_generation_prompt=False)
+ texts.append(text.strip())
+ images.append([image])
+
+ batch = self.processor(
+ text=texts,
+ images=images,
+ return_tensors="pt",
+ padding="max_length",
+ truncation=True,
+ max_length=self.max_seq_length,
+ )
+
+ labels = batch["input_ids"].clone()
+ labels[labels == self.processor.tokenizer.pad_token_id] = -100
+ labels[labels == self.image_token_id] = -100
+ batch["labels"] = labels
+
+ return batch
+
+
+def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, max_seq_length):
+ from tqdm import tqdm
+
+ answers_unique = []
+ generated_texts_unique = []
+
+ for i in tqdm(range(0, len(dataset), batch_size)):
+ examples = dataset[i : i + batch_size]
+ answers_unique.extend(examples["answers"])
+ images = [[im] for im in examples["image"]]
+ texts = []
+ for q in examples["query"]:
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Answer briefly."},
+ {"type": "image"},
+ {"type": "text", "text": q["en"]},
+ ],
+ }
+ ]
+ text = processor.apply_chat_template(messages, add_generation_prompt=True)
+ texts.append(text.strip())
+ inputs = processor(
+ text=texts,
+ images=images,
+ return_tensors="pt",
+ padding="max_length",
+ truncation=True,
+ max_length=max_seq_length,
+ )
+ inputs = {k: v.to("hpu") for k, v in inputs.items()}
+ generated_ids = model.generate(
+ **inputs, max_new_tokens=64, ignore_eos=False, lazy_mode=use_lazy_mode, hpu_graphs=use_hpu_graphs
+ )
+ generated_texts = processor.batch_decode(
+ generated_ids[:, inputs["input_ids"].size(1) :], skip_special_tokens=True
+ )
+ generated_texts_unique.extend(generated_texts)
+ generated_texts_unique = [g.strip().strip(".") for g in generated_texts_unique]
+ anls = average_normalized_levenshtein_similarity(
+ ground_truth=answers_unique,
+ predicted_answers=generated_texts_unique,
+ )
+ return anls
+
+
+def main():
+ parser = HfArgumentParser((ModelArguments, DataArguments, GaudiTrainingArguments, FinetuneArguments))
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+ # If we pass only one argument to the script and it's the path to a json file,
+ # let's parse it to get our arguments.
+ model_args, data_args, training_args, finetune_args = parser.parse_json_file(
+ json_file=os.path.abspath(sys.argv[1])
+ )
+ else:
+ (
+ model_args,
+ data_args,
+ training_args,
+ finetune_args,
+ ) = parser.parse_args_into_dataclasses()
+
+ # Setup logging
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ handlers=[logging.StreamHandler(sys.stdout)],
+ )
+ logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+
+ if is_main_process(training_args.local_rank):
+ transformers.utils.logging.set_verbosity_info()
+ transformers.utils.logging.enable_default_handler()
+ transformers.utils.logging.enable_explicit_format()
+
+ processor = AutoProcessor.from_pretrained(
+ model_args.model_name_or_path,
+ do_image_splitting=model_args.do_image_splitting,
+ padding_side="right",
+ )
+ setattr(processor.image_processor, "pad_to_longest_edge", True)
+ config_kwargs = {
+ "cache_dir": model_args.cache_dir,
+ "revision": model_args.model_revision,
+ "trust_remote_code": True if model_args.trust_remote_code else None,
+ "use_cache": False if training_args.gradient_checkpointing else model_args.use_cache,
+ "token": model_args.token,
+ }
+ if model_args.config_name:
+ config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
+ elif model_args.model_name_or_path:
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
+ else:
+ raise ValueError("Please provide value for model_name_or_path or config_name.")
+
+ # Load model
+ if model_args.model_name_or_path:
+ model_dtype = torch.bfloat16 if training_args.bf16 else None
+ model = AutoModelForVision2Seq.from_pretrained(
+ model_args.model_name_or_path,
+ from_tf=bool(".ckpt" in model_args.model_name_or_path),
+ config=config,
+ cache_dir=model_args.cache_dir,
+ revision=model_args.model_revision,
+ trust_remote_code=True if model_args.trust_remote_code else None,
+ torch_dtype=model_dtype,
+ low_cpu_mem_usage=model_args.low_cpu_mem_usage,
+ device_map=training_args.device.type if model_args.load_meta_device else None,
+ token=model_args.token,
+ )
+ else:
+ raise ValueError("Must provide model_name_or_path to load a pretrained CausalLM model.")
+
+ lora_config = LoraConfig(
+ r=finetune_args.lora_rank,
+ lora_alpha=finetune_args.lora_alpha,
+ lora_dropout=finetune_args.lora_dropout,
+ target_modules=finetune_args.lora_target_modules,
+ init_lora_weights="gaussian",
+ )
+ model = get_peft_model(model, lora_config)
+ model.print_trainable_parameters()
+
+ train_dataset = load_dataset(
+ data_args.dataset_name,
+ data_args.dataset_config_name,
+ cache_dir=model_args.cache_dir,
+ token=model_args.token,
+ trust_remote_code=model_args.trust_remote_code,
+ split="train",
+ )
+
+ train_dataset = train_dataset.remove_columns(
+ [
+ col
+ for col in train_dataset.column_names
+ if col not in (data_args.input_column_names + data_args.output_column_names)
+ ]
+ )
+
+ eval_dataset = load_dataset(
+ data_args.dataset_name,
+ data_args.dataset_config_name,
+ cache_dir=model_args.cache_dir,
+ token=model_args.token,
+ trust_remote_code=model_args.trust_remote_code,
+ split="test",
+ )
+
+ eval_dataset = eval_dataset.remove_columns(
+ [
+ col
+ for col in eval_dataset.column_names
+ if col not in (data_args.input_column_names + data_args.output_column_names)
+ ]
+ )
+ if hasattr(config, "image_token_id"):
+ # idefics
+ image_token_id = config.image_token_id
+ elif hasattr(config, "image_token_index"):
+ # mllama
+ image_token_id = config.image_token_index
+ else:
+ raise ValueError("Please provide value for image_token_id")
+ data_collator = MyDataCollator(processor, max_seq_length=data_args.max_seq_length, image_token_id=image_token_id)
+
+ gaudi_config = GaudiConfig()
+ gaudi_config.use_fused_adam = True
+ gaudi_config.use_fused_clip_norm = True
+
+ trainer = GaudiTrainer(
+ model=model,
+ args=training_args,
+ gaudi_config=gaudi_config,
+ data_collator=data_collator,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ )
+
+ if training_args.do_train:
+ train_result = trainer.train()
+ trainer.save_model()
+ metrics = train_result.metrics
+ trainer.log_metrics("train", metrics)
+ trainer.save_metrics("train", metrics)
+
+ if is_main_process(training_args.local_rank):
+ processor.tokenizer.padding_side = "left"
+
+ example = eval_dataset[15]
+ model.eval()
+ model = model.merge_and_unload()
+ if model_dtype == torch.bfloat16:
+ model = model.to(torch.bfloat16)
+
+ image = example["image"]
+ query = example["query"]
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Answer briefly."},
+ {"type": "image"},
+ {"type": "text", "text": query["en"]},
+ ],
+ }
+ ]
+ text = processor.apply_chat_template(messages, add_generation_prompt=True)
+ inputs = processor(
+ text=[text.strip()],
+ images=[image],
+ return_tensors="pt",
+ padding="max_length",
+ truncation=True,
+ max_length=data_args.max_seq_length,
+ )
+ inputs = {k: v.to("hpu") for k, v in inputs.items()}
+ generated_ids = model.generate(
+ **inputs,
+ max_new_tokens=64,
+ ignore_eos=False,
+ lazy_mode=training_args.use_lazy_mode,
+ hpu_graphs=training_args.use_hpu_graphs_for_inference,
+ )
+ generated_texts = processor.batch_decode(
+ generated_ids[:, inputs["input_ids"].size(1) :], skip_special_tokens=True
+ )
+ logger.info(f"generated: {generated_texts}")
+ if training_args.do_eval:
+ if training_args.use_hpu_graphs_for_inference:
+ from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+
+ model = wrap_in_hpu_graph(model)
+
+ anls = eval(
+ processor=processor,
+ model=model,
+ dataset=eval_dataset,
+ batch_size=training_args.per_device_eval_batch_size,
+ use_lazy_mode=training_args.use_lazy_mode,
+ use_hpu_graphs=training_args.use_hpu_graphs_for_inference,
+ max_seq_length=data_args.max_seq_length,
+ )
+ eval_metrics = {"eval_accuracy": anls}
+ trainer.log_metrics("eval", eval_metrics)
+ trainer.save_metrics("eval", eval_metrics)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/image-to-text/run_pipeline.py b/examples/image-to-text/run_pipeline.py
index 9f523fc3c7..75b391ea2e 100644
--- a/examples/image-to-text/run_pipeline.py
+++ b/examples/image-to-text/run_pipeline.py
@@ -23,7 +23,7 @@
import PIL.Image
import requests
import torch
-from transformers import AutoConfig, LlavaNextProcessor, LlavaProcessor, pipeline
+from transformers import AutoConfig, AutoModelForVision2Seq, AutoProcessor, pipeline
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
@@ -36,6 +36,53 @@
logger = logging.getLogger(__name__)
+def override_print(enable):
+ import builtins as __builtin__
+
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop("force", False)
+ if force or enable:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
+
+
+def override_logger(logger, enable):
+ logger_info = logger.info
+
+ def info(*args, **kwargs):
+ force = kwargs.pop("force", False)
+ if force or enable:
+ logger_info(*args, **kwargs)
+
+ logger.info = info
+
+
+def initialize_distributed_model(args, model, logger, model_dtype):
+ override_print(args.global_rank == 0)
+ override_logger(logger, args.global_rank == 0)
+
+ import deepspeed
+
+ logger.info(f"Initializing DeepSpeed with world size: {args.world_size}")
+ deepspeed.init_distributed(
+ dist_backend="hccl",
+ verbose=args.global_rank == 0,
+ )
+ model.eval()
+
+ ds_inference_kwargs = {"dtype": model_dtype}
+ ds_inference_kwargs["tensor_parallel"] = {"tp_size": args.world_size}
+ ds_inference_kwargs["enable_cuda_graph"] = args.use_hpu_graphs
+ ds_inference_kwargs["injection_policy"] = {}
+
+ model = deepspeed.init_inference(model, **ds_inference_kwargs).module
+
+ return model
+
+
def setup_quantization(model, args):
from neural_compressor.torch.quantization import FP8Config, convert, prepare
@@ -129,21 +176,23 @@ def main():
# set args.quant_config with env variable if it is set
args.quant_config = os.getenv("QUANT_CONFIG", "")
+
+ args.local_rank = int(os.getenv("LOCAL_RANK", "0"))
+ args.world_size = int(os.getenv("WORLD_SIZE", "0"))
+ args.global_rank = int(os.getenv("RANK", "0"))
+
os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")
adapt_transformers_to_gaudi()
model_type = AutoConfig.from_pretrained(args.model_name_or_path).model_type
- if args.image_path is None and model_type == "llava":
+ if args.image_path is None and model_type in ["llava", "idefics2", "mllama"]:
args.image_path = ["https://llava-vl.github.io/static/images/view.jpg"]
elif args.image_path is None and model_type == "llava_next":
args.image_path = [
"https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
]
- if args.prompt is None and model_type in ("llava", "llava_next"):
- if model_type == "llava":
- processor = LlavaProcessor.from_pretrained(args.model_name_or_path)
- elif model_type == "llava_next":
- processor = LlavaNextProcessor.from_pretrained(args.model_name_or_path)
+ if args.prompt is None and model_type in ["llava", "idefics2", "llava_next", "mllama"]:
+ processor = AutoProcessor.from_pretrained(args.model_name_or_path)
conversation = [
{
"role": "user",
@@ -181,12 +230,36 @@ def main():
htcore.hpu_set_env()
- generator = pipeline(
- "image-to-text",
- model=args.model_name_or_path,
- torch_dtype=model_dtype,
- device="hpu",
- )
+ if args.world_size > 1:
+ import deepspeed
+
+ with deepspeed.OnDevice(dtype=model_dtype, device="cpu"):
+ model = AutoModelForVision2Seq.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype)
+ if model_type == "mllama":
+ model.language_model = initialize_distributed_model(args, model.language_model, logger, model_dtype)
+ else:
+ model = initialize_distributed_model(args, model, logger, model_dtype)
+ generator = pipeline(
+ "image-to-text",
+ model=model,
+ config=args.model_name_or_path,
+ tokenizer=args.model_name_or_path,
+ image_processor=args.model_name_or_path,
+ torch_dtype=model_dtype,
+ device="hpu",
+ )
+ else:
+ generator = pipeline(
+ "image-to-text",
+ model=args.model_name_or_path,
+ torch_dtype=model_dtype,
+ device="hpu",
+ )
+ if args.use_hpu_graphs:
+ from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+
+ generator.model = wrap_in_hpu_graph(generator.model)
+
generate_kwargs = {
"lazy_mode": True,
"hpu_graphs": args.use_hpu_graphs,
@@ -198,15 +271,21 @@ def main():
if args.use_kv_cache:
generate_kwargs["use_cache"] = args.use_kv_cache
- if args.use_hpu_graphs:
- from habana_frameworks.torch.hpu import wrap_in_hpu_graph
-
- generator.model = wrap_in_hpu_graph(generator.model)
-
if args.quant_config:
generator.model = setup_quantization(generator.model, args)
htcore.hpu_initialize(generator.model)
+ # delete once pipeline integrate AutoProcessor as preprocess engine
+ if model_type in ["idefics2", "mllama"]:
+ from transformers.image_utils import load_image
+
+ def preprocess(self, image, prompt=None, timeout=None):
+ image = load_image(image, timeout=timeout)
+ model_inputs = processor(images=image, text=prompt, return_tensors=self.framework)
+ return model_inputs
+
+ generator.__class__.preprocess = preprocess
+
# warm up
for i in range(args.warmup):
generator(images, prompt=args.prompt, batch_size=args.batch_size, generate_kwargs=generate_kwargs)
diff --git a/examples/language-modeling/README.md b/examples/language-modeling/README.md
index 8ea0cdd554..f10e46d757 100644
--- a/examples/language-modeling/README.md
+++ b/examples/language-modeling/README.md
@@ -404,7 +404,7 @@ python3 run_lora_clm.py \
```
- Single-card finetuning of Falcon-40B:
```bash
-LOWER_LIST=ops_bf16.txt python3 run_lora_clm.py \
+PT_HPU_AUTOCAST_LOWER_PRECISION_OPS_LIST=ops_bf16.txt python3 run_lora_clm.py \
--model_name_or_path tiiuae/falcon-40b \
--dataset_name timdettmers/openassistant-guanaco \
--bf16 True \
@@ -474,39 +474,39 @@ python ../gaudi_spawn.py \
- Multi-card finetuning of Llama2-7B with FP8:
```bash
-LOWER_LIST=ops_bf16.txt python ../gaudi_spawn.py \
- --world_size 8 --use_mpi run_lora_clm.py \
- --model_name_or_path meta-llama/Llama-2-7b-hf \
- --dataset_name tatsu-lab/alpaca \
- --bf16 True \
- --output_dir ./model_lora_llama \
- --num_train_epochs 3 \
- --per_device_train_batch_size 16 \
- --gradient_accumulation_steps 1 \
- --eval_strategy "no" \
- --save_strategy "no" \
- --learning_rate 3e-4 \
- --warmup_ratio 0.03 \
- --lr_scheduler_type "constant" \
- --max_grad_norm 0.3 \
- --logging_steps 20 \
- --do_train \
- --do_eval \
- --use_habana \
- --use_lazy_mode \
- --throughput_warmup_steps 18 \
- --lora_rank=8 \
- --lora_alpha=16 \
- --lora_dropout=0.05 \
- --lora_target_modules "q_proj" "v_proj" \
- --dataset_concatenation \
- --max_seq_length 512 \
- --ddp_bucket_cap_mb 50 \
- --adam_epsilon 1e-08 \
- --validation_split_percentage 10 \
- --low_cpu_mem_usage True \
- --pipelining_fwd_bwd \
- --fp8 True
+PT_HPU_AUTOCAST_LOWER_PRECISION_OPS_LIST=ops_bf16.txt python ../gaudi_spawn.py \
+ --world_size 8 --use_mpi run_lora_clm.py \
+ --model_name_or_path meta-llama/Llama-2-7b-hf \
+ --dataset_name tatsu-lab/alpaca \
+ --bf16 True \
+ --output_dir ./model_lora_llama \
+ --num_train_epochs 3 \
+ --per_device_train_batch_size 16 \
+ --gradient_accumulation_steps 1 \
+ --eval_strategy "no" \
+ --save_strategy "no" \
+ --learning_rate 3e-4 \
+ --warmup_ratio 0.03 \
+ --lr_scheduler_type "constant" \
+ --max_grad_norm 0.3 \
+ --logging_steps 20 \
+ --do_train \
+ --do_eval \
+ --use_habana \
+ --use_lazy_mode \
+ --throughput_warmup_steps 18 \
+ --lora_rank=8 \
+ --lora_alpha=16 \
+ --lora_dropout=0.05 \
+ --lora_target_modules "q_proj" "v_proj" \
+ --dataset_concatenation \
+ --max_seq_length 512 \
+ --ddp_bucket_cap_mb 50 \
+ --adam_epsilon 1e-08 \
+ --validation_split_percentage 10 \
+ --low_cpu_mem_usage True \
+ --pipelining_fwd_bwd \
+ --fp8 True
```
- Multi-card finetuning of codegen-16B-mono:
@@ -569,7 +569,7 @@ python ../gaudi_spawn.py \
- Multi-card finetuning of Falcon-40B:
```bash
-LOWER_LIST=ops_bf16.txt python3 ../gaudi_spawn.py \
+PT_HPU_AUTOCAST_LOWER_PRECISION_OPS_LIST=ops_bf16.txt python3 ../gaudi_spawn.py \
--world_size 8 --use_mpi run_lora_clm.py \
--model_name_or_path tiiuae/falcon-40b \
--dataset_name timdettmers/openassistant-guanaco \
@@ -647,7 +647,7 @@ python3 ../gaudi_spawn.py --use_deepspeed --world_size 8 run_lora_clm.py \
- Multi-card finetuning of Llama2-70B with FSDP and LoRA:
```bash
-LOWER_LIST=ops_bf16.txt PT_HPU_LAZY_MODE=0 \
+PT_HPU_AUTOCAST_LOWER_PRECISION_OPS_LIST=ops_bf16.txt PT_HPU_LAZY_MODE=0 \
python3 ../gaudi_spawn.py --world_size 8 --use_mpi run_lora_clm.py \
--model_name_or_path meta-llama/Llama-2-70b-hf \
--dataset_name tatsu-lab/alpaca \
@@ -690,7 +690,7 @@ python3 ../gaudi_spawn.py --world_size 8 --use_mpi run_lora_clm.py \
- Falcon-180B example command saves only the LoRA parameters at end
- For inference we need to merge the pretrained model and LoRA weights
```bash
-DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED=1 LOWER_LIST=ops_bf16.txt python3 ../gaudi_spawn.py \
+PT_HPU_AUTOCAST_LOWER_PRECISION_OPS_LIST=ops_bf16.txt python3 ../gaudi_spawn.py \
--world_size 8 --use_deepspeed run_lora_clm.py \
--model_name_or_path tiiuae/falcon-180B \
--dataset_name timdettmers/openassistant-guanaco \
diff --git a/examples/language-modeling/requirements.txt b/examples/language-modeling/requirements.txt
index 955398ad19..4c2256d81b 100644
--- a/examples/language-modeling/requirements.txt
+++ b/examples/language-modeling/requirements.txt
@@ -1,4 +1,3 @@
-torch >= 1.3
datasets >= 2.14.0
sentencepiece != 0.1.92
protobuf
diff --git a/examples/language-modeling/run_lora_clm.py b/examples/language-modeling/run_lora_clm.py
index 1c6e29da25..4782ed58ae 100644
--- a/examples/language-modeling/run_lora_clm.py
+++ b/examples/language-modeling/run_lora_clm.py
@@ -700,6 +700,11 @@ def main():
raise ValueError("Must provide model_name_or_path to load a pretrained CausalLM model.")
if model.config.model_type == "llama":
+ if model.generation_config.pad_token_id is None:
+ if isinstance(model.generation_config.eos_token_id, int):
+ model.generation_config.pad_token_id = model.generation_config.eos_token_id
+ elif isinstance(model.generation_config.eos_token_id, list):
+ model.generation_config.pad_token_id = model.generation_config.eos_token_id[0]
if model_args.attn_softmax_bf16:
model.generation_config.attn_softmax_bf16 = True
if model_args.use_flash_attention:
@@ -717,7 +722,10 @@ def main():
if hasattr(model.generation_config, "pad_token_id") and model.generation_config.pad_token_id is not None:
tokenizer.pad_token_id = model.generation_config.pad_token_id
if hasattr(model.generation_config, "eos_token_id") and model.generation_config.eos_token_id is not None:
- tokenizer.eos_token_id = model.generation_config.eos_token_id
+ if isinstance(model.generation_config.eos_token_id, int):
+ tokenizer.eos_token_id = model.generation_config.eos_token_id
+ elif isinstance(model.generation_config.eos_token_id, list):
+ tokenizer.eos_token_id = model.generation_config.eos_token_id[0]
if hasattr(model.generation_config, "bos_token_id") and model.generation_config.bos_token_id is not None:
tokenizer.bos_token_id = model.generation_config.bos_token_id
diff --git a/examples/stable-diffusion/training/README.md b/examples/stable-diffusion/training/README.md
index 0a74b3c558..3d35a1623e 100644
--- a/examples/stable-diffusion/training/README.md
+++ b/examples/stable-diffusion/training/README.md
@@ -31,7 +31,7 @@ Let's get our dataset. For this example, we will use some cat images: https://hu
Let's first download it locally:
-```py
+```python
from huggingface_hub import snapshot_download
local_dir = "./cat"
@@ -61,9 +61,94 @@ python textual_inversion.py \
--throughput_warmup_steps 3
```
+The following example shows how to run inference using the fine-tuned model:
+
+```python
+from optimum.habana.diffusers import GaudiStableDiffusionPipeline
+import torch
+
+model_id = "/tmp/textual_inversion_cat"
+pipe = GaudiStableDiffusionPipeline.from_pretrained(
+ model_id,
+ torch_dtype=torch.bfloat16,
+ use_habana=True,
+ use_hpu_graphs=True,
+ gaudi_config="Habana/stable-diffusion",
+)
+
+prompt = "A backpack"
+image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+image.save(f"cat-backpack.png")
+```
+
> Change `--resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.
-> As described in [the official paper](https://arxiv.org/abs/2208.01618), only one embedding vector is used for the placeholder token, *e.g.* `""`. However, one can also add multiple embedding vectors for the placeholder token to increase the number of fine-tuneable parameters. This can help the model to learn more complex details. To use multiple embedding vectors, you can define `--num_vectors` to a number larger than one, *e.g.*: `--num_vectors 5`. The saved textual inversion vectors will then be larger in size compared to the default case.
+> As described in [the official paper](https://arxiv.org/abs/2208.01618), only one embedding vector is used for the placeholder token, *e.g.* `""`.
+> However, one can also add multiple embedding vectors for the placeholder token to increase the number of fine-tuneable parameters.
+> This can help the model to learn more complex details. To use multiple embedding vectors, you can define `--num_vectors` to a number larger than one,
+> *e.g.*: `--num_vectors 5`. The saved textual inversion vectors will then be larger in size compared to the default case.
+
+
+## Textual Inversion XL
+
+The `textual_inversion_sdxl.py` script shows how to implement textual inversion fine-tuning on Gaudi for XL diffusion models
+such as `stabilityai/stable-diffusion-xl-base-1.0` or `cagliostrolab/animagine-xl-3.1` for example.
+
+Assuming the afforemenioned cat toy dataset has been obtained, we can launch textual inversion XL training using:
+
+```bash
+python textual_inversion_sdxl.py \
+ --pretrained_model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \
+ --train_data_dir ./cat \
+ --learnable_property object \
+ --placeholder_token "" \
+ --initializer_token toy \
+ --resolution 768 \
+ --train_batch_size 1 \
+ --gradient_accumulation_steps 4 \
+ --max_train_steps 500 \
+ --learning_rate 5.0e-04 \
+ --scale_lr \
+ --lr_scheduler constant \
+ --lr_warmup_steps 0 \
+ --output_dir /tmp/textual_inversion_cat_sdxl \
+ --save_as_full_pipeline \
+ --gaudi_config_name Habana/stable-diffusion \
+ --throughput_warmup_steps 3
+```
+
+> As described in [the official paper](https://arxiv.org/abs/2208.01618), only one embedding vector is used for the placeholder token, *e.g.* `""`.
+> However, one can also add multiple embedding vectors for the placeholder token to increase the number of fine-tuneable parameters.
+> This can help the model to learn more complex details. To use multiple embedding vectors, you can define `--num_vectors` to a number larger than one,
+> *e.g.*: `--num_vectors 5`. The saved textual inversion vectors will then be larger in size compared to the default case.
+
+The script also supports training of both text encoders of SDXL, so inference can be executed by inserting a placeholder token into one or both prompts.
+The following example shows how to run inference using the fine tuned-model with both text encoders, separately and in combination:
+
+```python
+from optimum.habana.diffusers import GaudiStableDiffusionXLPipeline
+import torch
+
+model_id = "/tmp/textual_inversion_cat_sdxl"
+pipe = GaudiStableDiffusionXLPipeline.from_pretrained(
+ model_id,
+ torch_dtype=torch.bfloat16,
+ use_habana=True,
+ use_hpu_graphs=True,
+ gaudi_config="Habana/stable-diffusion",
+)
+
+prompt = "A backpack"
+image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+image.save(f"cat-backpack.png")
+
+image = pipe(prompt="", prompt_2=prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+image.save(f"cat-backpack_p2.png")
+
+prompt_2 = "A colored backpack"
+image = pipe(prompt=prompt, prompt_2=prompt_2, num_inference_steps=50, guidance_scale=7.5).images[0]
+image.save(f"cat-backpack_p1and2.png")
+```
## ControlNet Training
diff --git a/examples/stable-diffusion/training/textual_inversion_sdxl.py b/examples/stable-diffusion/training/textual_inversion_sdxl.py
new file mode 100644
index 0000000000..608ee481ad
--- /dev/null
+++ b/examples/stable-diffusion/training/textual_inversion_sdxl.py
@@ -0,0 +1,1076 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+import argparse
+import json
+import logging
+import math
+import os
+import random
+import shutil
+import time
+from pathlib import Path
+
+import diffusers
+import numpy as np
+import PIL
+import safetensors
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration
+from diffusers import (
+ DDPMScheduler,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from huggingface_hub import create_repo, upload_folder
+
+# TODO: remove and import from diffusers.utils when the new version of diffusers is released
+from packaging import version
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+
+from optimum.habana import GaudiConfig
+from optimum.habana.accelerate import GaudiAccelerator
+from optimum.habana.diffusers import (
+ GaudiStableDiffusionXLPipeline,
+)
+from optimum.habana.utils import set_seed
+
+
+if is_wandb_available():
+ import wandb
+
+if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.Resampling.BILINEAR,
+ "bilinear": PIL.Image.Resampling.BILINEAR,
+ "bicubic": PIL.Image.Resampling.BICUBIC,
+ "lanczos": PIL.Image.Resampling.LANCZOS,
+ "nearest": PIL.Image.Resampling.NEAREST,
+ }
+else:
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ "nearest": PIL.Image.NEAREST,
+ }
+# ------------------------------------------------------------------------------
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.23.1")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None):
+ img_str = ""
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- text-to-image
+- diffusers
+- textual_inversion
+inference: true
+---
+ """
+ model_card = f"""
+# Textual inversion text2image fine-tuning - {repo_id}
+These are textual inversion adaption weights for {base_model}. You can find some example images in the following. \n
+{img_str}
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def log_validation(
+ args,
+ accelerator,
+ weight_dtype,
+ epoch,
+ pipeline,
+ is_final_validation=False,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+
+ # run inference
+ pipeline.set_progress_bar_config(disable=True)
+ generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ images = []
+ for _ in range(args.num_validation_images):
+ image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+ images.append(image)
+
+ tracker_key = "test" if is_final_validation else "validation"
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(tracker_key, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ tracker_key: [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ return images
+
+
+def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path, safe_serialization=True):
+ logger.info("Saving embeddings")
+ learned_embeds = (
+ accelerator.unwrap_model(text_encoder)
+ .get_input_embeddings()
+ .weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1]
+ )
+ learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
+
+ if safe_serialization:
+ safetensors.torch.save_file(learned_embeds_dict, save_path, metadata={"format": "pt"})
+ else:
+ torch.save(learned_embeds_dict, save_path)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--save_steps",
+ type=int,
+ default=500,
+ help="Save learned_embeds.bin every X updates steps.",
+ )
+ parser.add_argument(
+ "--save_as_full_pipeline",
+ action="store_true",
+ help="Save the complete stable diffusion pipeline.",
+ )
+ parser.add_argument(
+ "--num_vectors",
+ type=int,
+ default=1,
+ help="How many textual inversion vectors shall be used to learn the concept.",
+ )
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
+ )
+ parser.add_argument(
+ "--placeholder_token",
+ type=str,
+ default=None,
+ required=True,
+ help="A token to use as a placeholder for the concept.",
+ )
+ parser.add_argument(
+ "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
+ )
+ parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
+ parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--bf16",
+ action="store_true",
+ default=False,
+ help=("Whether to use bf16 mixed precision."),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--scheduler",
+ default=None,
+ choices=["euler_discrete", "euler_ancestral_discrete", "ddim", "ddpm"],
+ type=str,
+ help="Name of scheduler",
+ )
+ parser.add_argument(
+ "--gaudi_config_name",
+ type=str,
+ default=None,
+ help="Local path to the Gaudi configuration file or its name on the Hugging Face Hub.",
+ )
+ parser.add_argument(
+ "--throughput_warmup_steps",
+ type=int,
+ default=0,
+ help=(
+ "Number of steps to ignore for throughput calculation. For example, with throughput_warmup_steps=N, the"
+ " first N steps will not be considered in the calculation of the throughput. This is especially useful in"
+ " lazy mode."
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.train_data_dir is None:
+ raise ValueError("You must specify a train data directory.")
+
+ return args
+
+
+imagenet_templates_small = [
+ "a photo of a {}",
+ "a rendering of a {}",
+ "a cropped photo of the {}",
+ "the photo of a {}",
+ "a photo of a clean {}",
+ "a photo of a dirty {}",
+ "a dark photo of the {}",
+ "a photo of my {}",
+ "a photo of the cool {}",
+ "a close-up photo of a {}",
+ "a bright photo of the {}",
+ "a cropped photo of a {}",
+ "a photo of the {}",
+ "a good photo of the {}",
+ "a photo of one {}",
+ "a close-up photo of the {}",
+ "a rendition of the {}",
+ "a photo of the clean {}",
+ "a rendition of a {}",
+ "a photo of a nice {}",
+ "a good photo of a {}",
+ "a photo of the nice {}",
+ "a photo of the small {}",
+ "a photo of the weird {}",
+ "a photo of the large {}",
+ "a photo of a cool {}",
+ "a photo of a small {}",
+]
+
+imagenet_style_templates_small = [
+ "a painting in the style of {}",
+ "a rendering in the style of {}",
+ "a cropped painting in the style of {}",
+ "the painting in the style of {}",
+ "a clean painting in the style of {}",
+ "a dirty painting in the style of {}",
+ "a dark painting in the style of {}",
+ "a picture in the style of {}",
+ "a cool painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a bright painting in the style of {}",
+ "a cropped painting in the style of {}",
+ "a good painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a rendition in the style of {}",
+ "a nice painting in the style of {}",
+ "a small painting in the style of {}",
+ "a weird painting in the style of {}",
+ "a large painting in the style of {}",
+]
+
+
+# check: shouldn't default size be 1024 for XL?
+class TextualInversionDataset(Dataset):
+ def __init__(
+ self,
+ data_root,
+ tokenizer_1,
+ tokenizer_2,
+ learnable_property="object", # [object, style]
+ size=512,
+ repeats=100,
+ interpolation="bicubic",
+ flip_p=0.5,
+ set="train",
+ placeholder_token="*",
+ center_crop=False,
+ ):
+ self.data_root = data_root
+ self.tokenizer_1 = tokenizer_1
+ self.tokenizer_2 = tokenizer_2
+ self.learnable_property = learnable_property
+ self.size = size
+ self.placeholder_token = placeholder_token
+ self.center_crop = center_crop
+ self.flip_p = flip_p
+
+ self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
+
+ self.num_images = len(self.image_paths)
+ self._length = self.num_images
+
+ if set == "train":
+ self._length = self.num_images * repeats
+
+ self.interpolation = {
+ "linear": PIL_INTERPOLATION["linear"],
+ "bilinear": PIL_INTERPOLATION["bilinear"],
+ "bicubic": PIL_INTERPOLATION["bicubic"],
+ "lanczos": PIL_INTERPOLATION["lanczos"],
+ }[interpolation]
+
+ self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
+ self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
+ self.crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = {}
+ image = Image.open(self.image_paths[i % self.num_images])
+
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ placeholder_string = self.placeholder_token
+ text = random.choice(self.templates).format(placeholder_string)
+
+ example["original_size"] = (image.height, image.width)
+
+ if self.center_crop:
+ y1 = max(0, int(round((image.height - self.size) / 2.0)))
+ x1 = max(0, int(round((image.width - self.size) / 2.0)))
+ image = self.crop(image)
+ else:
+ y1, x1, h, w = self.crop.get_params(image, (self.size, self.size))
+ image = transforms.functional.crop(image, y1, x1, h, w)
+
+ example["crop_top_left"] = (y1, x1)
+
+ example["input_ids_1"] = self.tokenizer_1(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=self.tokenizer_1.model_max_length,
+ return_tensors="pt",
+ ).input_ids[0]
+
+ example["input_ids_2"] = self.tokenizer_2(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=self.tokenizer_2.model_max_length,
+ return_tensors="pt",
+ ).input_ids[0]
+
+ # default to score-sde preprocessing
+ img = np.array(image).astype(np.uint8)
+
+ image = Image.fromarray(img)
+ image = image.resize((self.size, self.size), resample=self.interpolation)
+
+ image = self.flip_transform(image)
+ image = np.array(image).astype(np.uint8)
+ image = (image / 127.5 - 1.0).astype(np.float32)
+
+ example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
+ return example
+
+
+def main():
+ args = parse_args()
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ gaudi_config = GaudiConfig.from_pretrained(args.gaudi_config_name)
+
+ accelerator = GaudiAccelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision="bf16" if gaudi_config.use_torch_autocast or args.bf16 else "no",
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ force_autocast=gaudi_config.use_torch_autocast or args.bf16,
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ import habana_frameworks.torch.core as htcore
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load pipeline components
+ pipeline = GaudiStableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ revision=args.revision,
+ variant=args.variant,
+ use_habana=True,
+ use_hpu_graphs=True,
+ gaudi_config=args.gaudi_config_name,
+ )
+ text_encoder_1 = pipeline.text_encoder.to(accelerator.device)
+ text_encoder_2 = pipeline.text_encoder_2.to(accelerator.device)
+ tokenizer_1 = pipeline.tokenizer
+ tokenizer_2 = pipeline.tokenizer_2
+ unet = pipeline.unet
+ vae = pipeline.vae
+
+ # Load scheduler for training
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+ # Add the placeholder token in tokenizers
+ placeholder_tokens = [args.placeholder_token]
+
+ if args.num_vectors < 1:
+ raise ValueError(f"--num_vectors has to be larger or equal to 1, but is {args.num_vectors}")
+
+ # add dummy tokens for multi-vector
+ additional_tokens = []
+ for i in range(1, args.num_vectors):
+ additional_tokens.append(f"{args.placeholder_token}_{i}")
+ placeholder_tokens += additional_tokens
+
+ num_added_tokens = tokenizer_1.add_tokens(placeholder_tokens)
+ if num_added_tokens != args.num_vectors:
+ raise ValueError(
+ f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
+ " `placeholder_token` that is not already in the tokenizer."
+ )
+
+ num_added_tokens = tokenizer_2.add_tokens(placeholder_tokens)
+ if num_added_tokens != args.num_vectors:
+ raise ValueError(
+ f"The 2nd tokenizer already contains the token {args.placeholder_token}. Please pass a different"
+ " `placeholder_token` that is not already in the tokenizer."
+ )
+
+ # Convert the initializer_token, placeholder_token to ids
+ token_ids = tokenizer_1.encode(args.initializer_token, add_special_tokens=False)
+ token_ids_2 = tokenizer_2.encode(args.initializer_token, add_special_tokens=False)
+
+ # Check if initializer_token is a single token or a sequence of tokens
+ if len(token_ids) > 1 or len(token_ids_2) > 1:
+ raise ValueError("The initializer token must be a single token.")
+
+ initializer_token_id = token_ids[0]
+ initializer_token_id_2 = token_ids_2[0]
+ placeholder_token_ids = tokenizer_1.convert_tokens_to_ids(placeholder_tokens)
+ placeholder_token_ids_2 = tokenizer_2.convert_tokens_to_ids(placeholder_tokens)
+
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
+ text_encoder_1.resize_token_embeddings(len(tokenizer_1))
+ text_encoder_2.resize_token_embeddings(len(tokenizer_2))
+
+ # Initialise the newly added placeholder token with the embeddings of the initializer token
+ token_embeds = text_encoder_1.get_input_embeddings().weight.data
+ token_embeds_2 = text_encoder_2.get_input_embeddings().weight.data
+ with torch.no_grad():
+ for token_id in placeholder_token_ids:
+ token_embeds[token_id] = token_embeds[initializer_token_id].clone()
+ for token_id in placeholder_token_ids_2:
+ token_embeds_2[token_id] = token_embeds_2[initializer_token_id_2].clone()
+
+ # Freeze vae and unet
+ vae.requires_grad_(False)
+ unet.requires_grad_(False)
+
+ # Freeze all parameters except for the token embeddings in text encoder
+ text_encoder_1.text_model.encoder.requires_grad_(False)
+ text_encoder_1.text_model.final_layer_norm.requires_grad_(False)
+ text_encoder_1.text_model.embeddings.position_embedding.requires_grad_(False)
+ text_encoder_2.text_model.encoder.requires_grad_(False)
+ text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
+ text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)
+
+ if args.gradient_checkpointing:
+ text_encoder_1.gradient_checkpointing_enable()
+ text_encoder_2.gradient_checkpointing_enable()
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ if gaudi_config.use_fused_adam:
+ from habana_frameworks.torch.hpex.optimizers import FusedAdamW
+
+ optimizer_class = FusedAdamW
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ # only optimize the embeddings
+ [
+ text_encoder_1.text_model.embeddings.token_embedding.weight,
+ text_encoder_2.text_model.embeddings.token_embedding.weight,
+ ],
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ placeholder_token = " ".join(tokenizer_1.convert_ids_to_tokens(placeholder_token_ids))
+
+ # Dataset and DataLoaders creation:
+ train_dataset = TextualInversionDataset(
+ data_root=args.train_data_dir,
+ tokenizer_1=tokenizer_1,
+ tokenizer_2=tokenizer_2,
+ size=args.resolution,
+ placeholder_token=placeholder_token,
+ repeats=args.repeats,
+ learnable_property=args.learnable_property,
+ center_crop=args.center_crop,
+ set="train",
+ )
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ )
+
+ text_encoder_1.train()
+ text_encoder_2.train()
+ # Prepare everything with our `accelerator`.
+ text_encoder_1, text_encoder_2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ text_encoder_1, text_encoder_2, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if gaudi_config.use_torch_autocast or args.bf16:
+ weight_dtype = torch.bfloat16
+
+ # Move vae and unet and text_encoder_2 to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_2.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("textual_inversion", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ # keep original embeddings as reference
+ orig_embeds_params = accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight.data.clone()
+ orig_embeds_params_2 = accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight.data.clone()
+
+ t0 = None
+ # pipeline = None
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ text_encoder_1.train()
+ text_encoder_2.train()
+ for step, batch in enumerate(train_dataloader):
+ if t0 is None and global_step == args.throughput_warmup_steps:
+ t0 = time.perf_counter()
+
+ with accelerator.accumulate([text_encoder_1, text_encoder_2]):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states_1 = (
+ text_encoder_1(batch["input_ids_1"], output_hidden_states=True)
+ .hidden_states[-2]
+ .to(dtype=weight_dtype)
+ )
+ encoder_output_2 = text_encoder_2(batch["input_ids_2"], output_hidden_states=True)
+ encoder_hidden_states_2 = encoder_output_2.hidden_states[-2].to(dtype=weight_dtype)
+
+ original_size = [
+ (batch["original_size"][0][i].item(), batch["original_size"][1][i].item())
+ for i in range(args.train_batch_size)
+ ]
+ crop_top_left = [
+ (batch["crop_top_left"][0][i].item(), batch["crop_top_left"][1][i].item())
+ for i in range(args.train_batch_size)
+ ]
+ target_size = (args.resolution, args.resolution)
+ add_time_ids = torch.cat(
+ [
+ torch.tensor(original_size[i] + crop_top_left[i] + target_size)
+ for i in range(args.train_batch_size)
+ ]
+ ).to(accelerator.device, dtype=weight_dtype)
+ added_cond_kwargs = {"text_embeds": encoder_output_2[0], "time_ids": add_time_ids}
+ encoder_hidden_states = torch.cat([encoder_hidden_states_1, encoder_hidden_states_2], dim=-1)
+
+ # Predict the noise residual
+ model_pred = unet(
+ noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+ ).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ accelerator.backward(loss)
+ htcore.mark_step()
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=True)
+ htcore.mark_step()
+
+ # Let's make sure we don't update any embedding weights besides the newly added token
+ index_no_updates = torch.ones((len(tokenizer_1),), dtype=torch.bool)
+ index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
+ index_no_updates_2 = torch.ones((len(tokenizer_2),), dtype=torch.bool)
+ index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False
+
+ with torch.no_grad():
+ accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[index_no_updates] = (
+ orig_embeds_params[index_no_updates]
+ )
+ accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[index_no_updates_2] = (
+ orig_embeds_params_2[index_no_updates_2]
+ )
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ images = []
+ progress_bar.update(1)
+ global_step += 1
+ if global_step % args.save_steps == 0:
+ weight_name = f"learned_embeds-steps-{global_step}.safetensors"
+ save_path = os.path.join(args.output_dir, weight_name)
+ save_progress(
+ text_encoder_1,
+ placeholder_token_ids,
+ accelerator,
+ args,
+ save_path,
+ safe_serialization=True,
+ )
+ weight_name = f"learned_embeds_2-steps-{global_step}.safetensors"
+ save_path = os.path.join(args.output_dir, weight_name)
+ save_progress(
+ text_encoder_2,
+ placeholder_token_ids_2,
+ accelerator,
+ args,
+ save_path,
+ safe_serialization=True,
+ )
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ images = log_validation(
+ args,
+ accelerator,
+ weight_dtype,
+ epoch,
+ pipeline,
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ duration = time.perf_counter() - t0
+ throughput = args.max_train_steps * total_batch_size / duration
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ if args.validation_prompt:
+ images = log_validation(
+ args,
+ accelerator,
+ weight_dtype,
+ epoch,
+ pipeline,
+ is_final_validation=True,
+ )
+
+ logger.info(f"Throughput = {throughput} samples/s")
+ logger.info(f"Train runtime = {duration} seconds")
+ metrics = {
+ "train_samples_per_second": throughput,
+ "train_runtime": duration,
+ }
+ with open(f"{args.output_dir}/speed_metrics.json", mode="w") as file:
+ json.dump(metrics, file)
+
+ if args.push_to_hub and not args.save_as_full_pipeline:
+ logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
+ save_full_model = True
+ else:
+ save_full_model = args.save_as_full_pipeline
+ if save_full_model:
+ pipeline.save_pretrained(args.output_dir)
+
+ # Save the newly trained embeddings
+ weight_name = "learned_embeds.safetensors"
+ save_path = os.path.join(args.output_dir, weight_name)
+ save_progress(
+ text_encoder_1,
+ placeholder_token_ids,
+ accelerator,
+ args,
+ save_path,
+ safe_serialization=True,
+ )
+ weight_name = "learned_embeds_2.safetensors"
+ save_path = os.path.join(args.output_dir, weight_name)
+ save_progress(
+ text_encoder_2,
+ placeholder_token_ids_2,
+ accelerator,
+ args,
+ save_path,
+ safe_serialization=True,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/stable-diffusion/unconditional_image_generation.py b/examples/stable-diffusion/unconditional_image_generation.py
index baca71b6ba..36e35ff90f 100644
--- a/examples/stable-diffusion/unconditional_image_generation.py
+++ b/examples/stable-diffusion/unconditional_image_generation.py
@@ -79,6 +79,12 @@ def main():
default="/tmp/",
help="Where to save the generated images. The default is DDPMScheduler.",
)
+ parser.add_argument(
+ "--throughput_warmup_steps",
+ type=int,
+ default=3,
+ help="Number of steps to ignore for throughput calculation.",
+ )
args = parser.parse_args()
model_name = args.model_name_or_path
@@ -100,8 +106,10 @@ def main():
"gaudi_config": gaudi_config,
}
+ kwargs_call = {"throughput_warmup_steps": args.throughput_warmup_steps}
+
pipeline = GaudiDDPMPipeline.from_pretrained(model_name, **kwargs)
- output = pipeline(batch_size=args.batch_size, num_inference_steps=args.num_inference_steps)
+ output = pipeline(batch_size=args.batch_size, num_inference_steps=args.num_inference_steps, **kwargs_call)
if args.output_dir:
logger.info(f"Generating outputs to {args.output_dir}")
diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index 74e4f600d7..ee548bc1c8 100755
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -561,12 +561,16 @@ def rounder(x):
print()
print("Input/outputs:")
+ all_inputs = []
+ all_outputs = []
for i, input_sentence in enumerate(zip(input_sentences)):
print(f"input {i+1}: {input_sentence}")
+ all_inputs.append(input_sentence)
for j, output in enumerate(
zip(generated[args.num_return_sequences * i : args.num_return_sequences * (i + 1)])
):
print(f"output {j+1}: {output}")
+ all_outputs.append(output)
print()
# Store results if necessary
@@ -576,7 +580,8 @@ def rounder(x):
results = {
"throughput": throughput,
- "output": output,
+ "input": all_inputs,
+ "output": all_outputs,
}
with (output_dir / "results.json").open("w", encoding="utf-8") as f:
json.dump(results, f, ensure_ascii=False, indent=4)
diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py
index 3299cadcbe..c827291416 100644
--- a/examples/text-generation/run_lm_eval.py
+++ b/examples/text-generation/run_lm_eval.py
@@ -195,6 +195,14 @@ def main():
args = setup_lm_eval_parser()
model, _, tokenizer, generation_config = initialize_model(args, logger)
+ if args.trust_remote_code:
+ # trust_remote_code fix was introduced in lm_eval 0.4.3
+ # https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
+ # We need to cherry-pick the fix manually untill we upgrade (SW-190418)
+ import datasets
+
+ datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
+
lm_tasks = lm_eval.tasks.get_task_dict(args.tasks)
with torch.no_grad():
lm = HabanaModelAdapter(tokenizer, model, args, generation_config)
diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py
index 3e21f55445..b4027fec5e 100644
--- a/examples/text-generation/utils.py
+++ b/examples/text-generation/utils.py
@@ -180,23 +180,24 @@ def get_torch_compiled_model(model, logger):
# for gpt_bigcode, mpt, bloom, gpt2 model_type
if hasattr(model, 'transformer'):
model.transformer = torch.compile(
- model.transformer, backend="hpu_backend", options={"keep_input_mutations": True}
+ model.transformer, backend="hpu_backend"
)
# for gpt_neox
elif hasattr(model, 'gpt_neox'):
model.gpt_neox = torch.compile(
- model.gpt_neox, backend="hpu_backend", options={"keep_input_mutations": True}
+ model.gpt_neox, backend="hpu_backend"
)
# for llama, mistral, mixtral, qwen2
elif hasattr(model, 'model'):
model.model = torch.compile(
- model.model, backend="hpu_backend", options={"keep_input_mutations": True}
+ model.model, backend="hpu_backend"
)
else:
logger.warning(
"in low performance case, please explicitly specify a module you want wrap with `torch.compile`"
)
- model = torch.compile(model, backend="hpu_backend", options={"keep_input_mutations": True})
+ model = torch.compile(model, backend="hpu_backend")
+
return model
@@ -281,6 +282,9 @@ def setup_model(args, model_dtype, model_kwargs, logger):
original_model=org_model,
**model_kwargs,
)
+ # TODO: This will be removed in v1.19 Synapse release
+ # the loaded model should have the same dtype as original_model
+ model = model.to(model_kwargs["torch_dtype"])
else:
if args.assistant_model is not None:
assistant_model = AutoModelForCausalLM.from_pretrained(
diff --git a/examples/text-to-speech/README.md b/examples/text-to-speech/README.md
index 5b98b30493..a1e089f55e 100644
--- a/examples/text-to-speech/README.md
+++ b/examples/text-to-speech/README.md
@@ -36,4 +36,5 @@ python3 run_pipeline.py \
```
Models that have been validated:
- [microsoft/speecht5_tts](https://huggingface.co/microsoft/speecht5_tts)
+ - [facebook/hf-seamless-m4t-medium](https://huggingface.co/facebook/hf-seamless-m4t-medium)
- [facebook/mms-tts-eng](https://huggingface.co/facebook/mms-tts-eng)
diff --git a/optimum/habana/diffusers/pipelines/ddpm/pipeline_ddpm.py b/optimum/habana/diffusers/pipelines/ddpm/pipeline_ddpm.py
index 7b3ea5afdb..65a7df7e2d 100644
--- a/optimum/habana/diffusers/pipelines/ddpm/pipeline_ddpm.py
+++ b/optimum/habana/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -29,6 +29,8 @@
from optimum.habana.transformers.gaudi_configuration import GaudiConfig
from optimum.utils import logging
+from ....utils import speed_metrics
+
logger = logging.get_logger(__name__)
@@ -149,8 +151,14 @@ def __call__(
if self.use_habana:
self.unet = self.unet.to(self._device)
+ throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
+
start_time = time.time()
+ time_after_warmup = start_time
for i in self.progress_bar(num_inference_steps):
+ if i == throughput_warmup_steps:
+ time_after_warmup = time.time()
+
timestep = timesteps[0]
timesteps = torch.roll(timesteps, shifts=-1, dims=0)
@@ -172,7 +180,16 @@ def __call__(
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
- end_time = time.time()
+
+ speed_metrics_prefix = "generation"
+ speed_measures = speed_metrics(
+ split=speed_metrics_prefix,
+ start_time=start_time,
+ num_samples=batch_size,
+ num_steps=batch_size * len(num_inference_steps),
+ start_time_after_warmup=time_after_warmup,
+ )
+ logger.info(f"Speed metrics: {speed_measures}")
# Offload all models
self.maybe_free_model_hooks()
@@ -180,5 +197,5 @@ def __call__(
if not return_dict:
return (image,)
- throughput = (end_time - start_time) / batch_size
+ throughput = speed_measures["generation_samples_per_second"]
return GaudiDDPMPipelineOutput(images=image, throughput=throughput)
diff --git a/optimum/habana/transformers/gaudi_configuration.py b/optimum/habana/transformers/gaudi_configuration.py
index 76638d8e95..faeceb8be8 100644
--- a/optimum/habana/transformers/gaudi_configuration.py
+++ b/optimum/habana/transformers/gaudi_configuration.py
@@ -93,5 +93,5 @@ def declare_autocast_bf16_fp32_ops(self):
autocast_bf16_filename,
autocast_fp32_filename,
)
- os.environ["LOWER_LIST"] = autocast_bf16_filename
- os.environ["FP32_LIST"] = autocast_fp32_filename
+ os.environ["PT_HPU_AUTOCAST_LOWER_PRECISION_OPS_LIST"] = autocast_bf16_filename
+ os.environ["PT_HPU_AUTOCAST_FP32_OPS_LIST"] = autocast_fp32_filename
diff --git a/optimum/habana/transformers/generation/stopping_criteria.py b/optimum/habana/transformers/generation/stopping_criteria.py
index 69325ab7b3..844ffa50f2 100644
--- a/optimum/habana/transformers/generation/stopping_criteria.py
+++ b/optimum/habana/transformers/generation/stopping_criteria.py
@@ -38,8 +38,10 @@ def gaudi_MaxLengthCriteria_call(
) -> Union[torch.BoolTensor, bool]:
token_idx = kwargs.get("token_idx", None)
if token_idx is not None:
- assert not kwargs["needs_tensor_output"]
- return token_idx >= self.max_length
+ if not kwargs["needs_tensor_output"]:
+ return token_idx >= self.max_length
+ else:
+ return create_return_const_tensor(input_ids, token_idx >= self.max_length)
else:
cur_len = input_ids.shape[-1]
is_done = cur_len >= self.max_length
@@ -68,7 +70,6 @@ def gaudi_EosTokenCriteria_call(
self.eos_token_id = self.eos_token_id.to(input_ids.device)
token_idx = kwargs.get("token_idx", None)
if token_idx is not None:
- assert not kwargs["needs_tensor_output"]
is_done = torch.isin(input_ids[:, token_idx - 1], self.eos_token_id)
else:
is_done = torch.isin(input_ids[:, -1], self.eos_token_id)
@@ -78,19 +79,15 @@ def gaudi_EosTokenCriteria_call(
return torch.all(is_done).item()
-def needs_tensor_output(token_idx, ignore_eos, eos_token_id) -> bool:
- if token_idx is None:
- return not ignore_eos and eos_token_id is not None
- else:
- # token_idx is present, so we have static shapes, so using single boolean
- return False
+def needs_tensor_output(ignore_eos, eos_token_id) -> bool:
+ return not ignore_eos and eos_token_id is not None
def gaudi_StoppingCriteriaList_call(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> Union[torch.BoolTensor, bool]:
kwargs["needs_tensor_output"] = needs_tensor_output(
- kwargs.get("token_idx", None), kwargs.get("ignore_eos", True), kwargs.get("eos_token_id", None)
+ kwargs.get("ignore_eos", True), kwargs.get("eos_token_id", None)
)
is_done = (
torch.full((input_ids.shape[0],), 0, device=input_ids.device, dtype=torch.int8)
diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index 2025b387d6..f088974f3f 100644
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -104,8 +104,12 @@
"stablelm",
"mamba",
"deci",
+ "cohere",
"qwen2_moe",
+ "xglm",
"whisper",
+ "idefics2",
+ "mllama",
]
@@ -237,14 +241,20 @@ def _prepare_decoder_input_ids_for_generation(
if token_idx is None:
decoder_input_ids = torch.cat([decoder_start_token_id, decoder_input_ids], dim=-1)
else:
- max_length = max_new_tokens + 2 if max_new_tokens is not None else self.generation_config.max_length
+ decoder_input_ids_len = decoder_input_ids.shape[-1]
+ max_length = (
+ max_new_tokens + decoder_input_ids_len + 1
+ if max_new_tokens is not None
+ else self.generation_config.max_length
+ )
if max_length != decoder_start_token_id.shape[-1]:
decoder_start_token_id = torch.nn.functional.pad(
decoder_start_token_id,
(0, max_length - decoder_start_token_id.shape[-1]),
value=pad_token_id,
)
- decoder_input_ids = decoder_start_token_id.index_copy(1, token_idx, decoder_input_ids)
+ decoder_start_token_id[:, 1 : 1 + decoder_input_ids_len, ...] = decoder_input_ids
+ decoder_input_ids = decoder_start_token_id
token_idx.add_(1)
if "decoder_attention_mask" in model_kwargs:
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
@@ -321,11 +331,13 @@ def _expand_dict_for_generation(dict_to_expand):
def _pad_past_key_values(self, model_kwargs):
pad_amount = model_kwargs.get("kv_cache_pad_len", 0)
+ kv_cache_len = model_kwargs.get("kv_cache_len", 0)
if model_kwargs["past_key_values"]:
if model_kwargs.get("mqa_model", False):
for i in range(len(model_kwargs["past_key_values"])): # layer
- if torch.is_tensor(
- model_kwargs["past_key_values"][i]
+ if (
+ torch.is_tensor(model_kwargs["past_key_values"][i])
+ and model_kwargs["past_key_values"][i].shape[-2] == kv_cache_len - pad_amount
): # tensor(batch_size, kv_cache_len, n_heads * head_dim * 2) k and v stacked
model_kwargs["past_key_values"][i] = torch.nn.functional.pad(
model_kwargs["past_key_values"][i], (0, 0, 0, pad_amount)
@@ -335,8 +347,9 @@ def _pad_past_key_values(self, model_kwargs):
else:
for i in range(len(model_kwargs["past_key_values"])): # layer
for j in range(len(model_kwargs["past_key_values"][i])): # k or v
- if torch.is_tensor(
- model_kwargs["past_key_values"][i][j]
+ if (
+ torch.is_tensor(model_kwargs["past_key_values"][i][j])
+ and model_kwargs["past_key_values"][i][j].shape[-2] == kv_cache_len - pad_amount
): # tensor(batch_size, n_heads, kv_cache_len, head_dim)
model_kwargs["past_key_values"][i][j] = torch.nn.functional.pad(
model_kwargs["past_key_values"][i][j], (0, 0, 0, pad_amount)
@@ -452,6 +465,14 @@ def update_model_kwargs_for_bucketing(
)
else:
assert False, "Not tested for cases where attn_mask isnt passed"
+
+ if model_kwargs.get("cross_attention_mask") is not None:
+ model_kwargs["cross_attention_mask"] = torch.nn.functional.pad(
+ model_kwargs["cross_attention_mask"],
+ (0, 0, 0, 0, 0, pad_amount),
+ value=0,
+ )
+
if reduce_recompile and params["passnum"] == 0:
position_ids_cpu = model_kwargs["attention_mask"].long().cumsum(-1) - 1
position_ids_cpu.masked_fill_(model_kwargs["attention_mask"] == 0, 1)
@@ -494,14 +515,20 @@ def create_pad_arg(pad_amount, i, j):
# This is a necessary (but not sufficient) condition: what ever dimension we are padding, should be a multiple of bucket_size
# This check is added in case we get a new model with a new kv-cache structure, and we attempt to pad some wrong dimension
# in peft case, if there's virtual token. the model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)] % bucket_size == num_virtual_token, no need of assert, the pad length of past_key_value should be aligned with input id and attention_mask
- num_virtual_tokens = model_kwargs.get("num_virtual_tokens", 0)
- assert (
- model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)] % bucket_size
- == num_virtual_tokens
- )
- tmp_lst[j] = torch.nn.functional.pad(
- model_kwargs["past_key_values"][i][j], pad_tuple, value=pad_token_id
- )
+ if (
+ model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)]
+ == params["allocated_space"] - pad_amount
+ ):
+ num_virtual_tokens = model_kwargs.get("num_virtual_tokens", 0)
+ assert (
+ model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)] % bucket_size
+ == num_virtual_tokens
+ )
+ tmp_lst[j] = torch.nn.functional.pad(
+ model_kwargs["past_key_values"][i][j], pad_tuple, value=pad_token_id
+ )
+ else:
+ tmp_lst[j] = model_kwargs["past_key_values"][i][j]
new_kv[i] = tuple(tmp_lst)
model_kwargs["past_key_values"] = tuple(new_kv)
@@ -1101,13 +1128,22 @@ def generate(
(0, generation_config.max_new_tokens),
value=0,
)
+ if model_kwargs.get("cross_attention_mask") is not None:
+ model_kwargs["cross_attention_mask"] = torch.nn.functional.pad(
+ model_kwargs["cross_attention_mask"],
+ (0, 0, 0, 0, 0, generation_config.max_new_tokens),
+ value=0,
+ )
else:
assert generation_config.bucket_size <= 0, "Untested path for bucket>0"
- token_idx = 1
+ if model_kwargs.get("decoder_input_ids", None) is None:
+ token_idx = 1
+ else:
+ token_idx = model_kwargs["decoder_input_ids"].shape[-1]
model_kwargs["token_idx"] = torch.tensor(token_idx, device=inputs_tensor.device)
if model_kwargs.get("decoder_attention_mask", None) is None and generation_config.use_cache:
max_length = (
- generation_config.max_new_tokens + 1
+ generation_config.max_new_tokens + token_idx
if generation_config.max_new_tokens is not None
else generation_config.max_length
)
diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py
index 2fd24148be..bdc5879bbf 100644
--- a/optimum/habana/transformers/modeling_utils.py
+++ b/optimum/habana/transformers/modeling_utils.py
@@ -30,6 +30,7 @@
GAUDI_WHISPER_ATTENTION_CLASSES,
DeciLMConfig,
DeciLMForCausalLM,
+ Gaudi2Idefics2ImageProcessor,
GaudiBloomForCausalLM,
GaudiBloomMLP,
GaudiCLIPAttention,
@@ -40,6 +41,8 @@
GaudiCLIPVisionTransformer,
GaudiCodeGenAttention,
GaudiCodeGenForCausalLM,
+ GaudiCohereDecoderLayer,
+ GaudiCohereForCausalLM,
GaudiFalconAttention,
GaudiFalconDecoderLayer,
GaudiFalconForCausalLM,
@@ -64,6 +67,9 @@
GaudiGPTNeoXAttention,
GaudiGPTNeoXForCausalLM,
GaudiGPTNeoXLayer,
+ GaudiIdefics2ForConditionalGeneration,
+ GaudiIdefics2Model,
+ GaudiIdefics2VisionEmbeddings,
GaudiLlamaAttention,
GaudiLlamaDecoderLayer,
GaudiLlamaDynamicNTKScalingRotaryEmbedding,
@@ -82,6 +88,14 @@
GaudiMixtralDecoderLayer,
GaudiMixtralForCausalLM,
GaudiMixtralModel,
+ GaudiMllamaCrossAttentionDecoderLayer,
+ GaudiMllamaForCausalLM,
+ GaudiMllamaForConditionalGeneration,
+ GaudiMllamaSelfAttentionDecoderLayer,
+ GaudiMllamaTextCrossAttention,
+ GaudiMllamaTextModel,
+ GaudiMllamaTextSelfAttention,
+ GaudiMllamaVisionModel,
GaudiMptAttention,
GaudiMptBlock,
GaudiMptForCausalLM,
@@ -117,6 +131,7 @@
GaudiWhisperForConditionalGeneration,
GaudiWhisperModel,
GaudiWhisperSdpaAttention,
+ GaudiXGLMForCausalLM,
LlamaConfig,
MistralConfig,
MixtralConfig,
@@ -150,6 +165,8 @@
gaudi_check_and_enable_sdpa,
gaudi_codegen_block_forward,
gaudi_codegen_model_forward,
+ gaudi_cohere_attention_forward,
+ gaudi_cohere_model_forward,
gaudi_conv1d_forward,
gaudi_DetrConvModel_forward,
gaudi_esm_for_protein_folding_forward,
@@ -212,6 +229,9 @@
gaudi_wav2vec2_forward,
gaudi_wav2vec2_tdnnlayer_forward,
gaudi_wav2vec2forctc_forward,
+ gaudi_xglm_attention_forward,
+ gaudi_xglm_decoder_layer_forward,
+ gaudi_xglm_model_forward,
)
@@ -410,6 +430,14 @@ def adapt_transformers_to_gaudi():
GaudiLlavaNextForConditionalGeneration
)
+ # Optimization for idefics2 on Gaudi
+ transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration = (
+ GaudiIdefics2ForConditionalGeneration
+ )
+ transformers.models.idefics2.modeling_idefics2.Idefics2Model = GaudiIdefics2Model
+ transformers.models.idefics2.image_processing_idefics2.Idefics2ImageProcessor = Gaudi2Idefics2ImageProcessor
+ transformers.models.idefics2.modeling_idefics2.Idefics2VisionEmbeddings = GaudiIdefics2VisionEmbeddings
+
# Optimization for Clip on Gaudi
transformers.models.clip.modeling_clip.CLIPVisionEmbeddings = GaudiCLIPVisionEmbeddings
transformers.models.clip.modeling_clip.CLIPAttention = GaudiCLIPAttention
@@ -602,5 +630,27 @@ def adapt_transformers_to_gaudi():
transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration = GaudiWhisperForConditionalGeneration
transformers.models.whisper.modeling_whisper.WHISPER_ATTENTION_CLASSES = GAUDI_WHISPER_ATTENTION_CLASSES
+ # Optimization for mllama on Gaudi
+ transformers.models.mllama.modeling_mllama.MllamaSelfAttentionDecoderLayer = GaudiMllamaSelfAttentionDecoderLayer
+ transformers.models.mllama.modeling_mllama.MllamaCrossAttentionDecoderLayer = GaudiMllamaCrossAttentionDecoderLayer
+ transformers.models.mllama.modeling_mllama.MllamaForCausalLM = GaudiMllamaForCausalLM
+ transformers.models.mllama.modeling_mllama.MllamaTextSelfAttention = GaudiMllamaTextSelfAttention
+ transformers.models.mllama.modeling_mllama.MllamaTextCrossAttention = GaudiMllamaTextCrossAttention
+ transformers.models.mllama.modeling_mllama.MllamaForConditionalGeneration = GaudiMllamaForConditionalGeneration
+ transformers.models.mllama.modeling_mllama.MllamaTextModel = GaudiMllamaTextModel
+ transformers.models.mllama.modeling_mllama.MllamaVisionModel = GaudiMllamaVisionModel
+
transformers.AutoConfig.register("deci", DeciLMConfig)
transformers.AutoModelForCausalLM.register(DeciLMConfig, DeciLMForCausalLM)
+
+ # Optimization for cohere on Gaudi
+ transformers.models.cohere.modeling_cohere.CohereDecoderLayer = GaudiCohereDecoderLayer
+ transformers.models.cohere.modeling_cohere.CohereForCausalLM = GaudiCohereForCausalLM
+ transformers.models.cohere.modeling_cohere.CohereModel.forward = gaudi_cohere_model_forward
+ transformers.models.cohere.modeling_cohere.CohereAttention.forward = gaudi_cohere_attention_forward
+
+ # Optimization for xglm on Gaudi
+ transformers.models.xglm.modeling_xglm.XGLMForCausalLM = GaudiXGLMForCausalLM
+ transformers.models.xglm.modeling_xglm.XGLMModel.forward = gaudi_xglm_model_forward
+ transformers.models.xglm.modeling_xglm.XGLMAttention.forward = gaudi_xglm_attention_forward
+ transformers.models.xglm.modeling_xglm.XGLMDecoderLayer.forward = gaudi_xglm_decoder_layer_forward
diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py
index 8c9a045efa..94ee04b5b4 100644
--- a/optimum/habana/transformers/models/__init__.py
+++ b/optimum/habana/transformers/models/__init__.py
@@ -45,6 +45,12 @@
gaudi_codegen_block_forward,
gaudi_codegen_model_forward,
)
+from .cohere import (
+ GaudiCohereDecoderLayer,
+ GaudiCohereForCausalLM,
+ gaudi_cohere_attention_forward,
+ gaudi_cohere_model_forward,
+)
from .decilm import (
DeciLMConfig,
DeciLMForCausalLM,
@@ -103,6 +109,12 @@
GaudiGPTJForCausalLM,
GaudiGPTJModel,
)
+from .idefics2 import (
+ Gaudi2Idefics2ImageProcessor,
+ GaudiIdefics2ForConditionalGeneration,
+ GaudiIdefics2Model,
+ GaudiIdefics2VisionEmbeddings,
+)
from .llama import (
GaudiLlamaAttention,
GaudiLlamaDecoderLayer,
@@ -138,6 +150,16 @@
gaudi_mixtral_block_sparse_moe_forward,
gaudi_mixtral_rmsnorm_forward,
)
+from .mllama import (
+ GaudiMllamaCrossAttentionDecoderLayer,
+ GaudiMllamaForCausalLM,
+ GaudiMllamaForConditionalGeneration,
+ GaudiMllamaSelfAttentionDecoderLayer,
+ GaudiMllamaTextCrossAttention,
+ GaudiMllamaTextModel,
+ GaudiMllamaTextSelfAttention,
+ GaudiMllamaVisionModel,
+)
from .modeling_all_models import (
gaudi_check_and_enable_sdpa,
gaudi_conv1d_forward,
@@ -250,3 +272,9 @@
GaudiWhisperModel,
GaudiWhisperSdpaAttention,
)
+from .xglm import (
+ GaudiXGLMForCausalLM,
+ gaudi_xglm_attention_forward,
+ gaudi_xglm_decoder_layer_forward,
+ gaudi_xglm_model_forward,
+)
diff --git a/optimum/habana/transformers/models/clip/modeling_clip.py b/optimum/habana/transformers/models/clip/modeling_clip.py
index 96b03ab32a..ef5d604ec9 100644
--- a/optimum/habana/transformers/models/clip/modeling_clip.py
+++ b/optimum/habana/transformers/models/clip/modeling_clip.py
@@ -86,7 +86,7 @@ def forward(
- add new args use_flash_attention to enable FusedSDPA
- add new args flash_attention_recompute
"""
- bsz, tgt_len, embed_dim = hidden_states.size()
+ bsz, tgt_len, _ = hidden_states.size()
attn_weights_reshaped = None
# get query proj
query_states = self.q_proj(hidden_states) * self.scale
@@ -156,7 +156,7 @@ def forward(
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
- attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+ attn_output = attn_output.reshape(bsz, tgt_len, -1)
attn_output = self.out_proj(attn_output)
diff --git a/optimum/habana/transformers/models/cohere/__init__.py b/optimum/habana/transformers/models/cohere/__init__.py
new file mode 100644
index 0000000000..ec3a43831c
--- /dev/null
+++ b/optimum/habana/transformers/models/cohere/__init__.py
@@ -0,0 +1,6 @@
+from .modeling_cohere import (
+ GaudiCohereDecoderLayer,
+ GaudiCohereForCausalLM,
+ gaudi_cohere_attention_forward,
+ gaudi_cohere_model_forward,
+)
diff --git a/optimum/habana/transformers/models/cohere/modeling_cohere.py b/optimum/habana/transformers/models/cohere/modeling_cohere.py
new file mode 100644
index 0000000000..c0785c88ed
--- /dev/null
+++ b/optimum/habana/transformers/models/cohere/modeling_cohere.py
@@ -0,0 +1,441 @@
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.models.cohere.modeling_cohere import (
+ Cache,
+ CohereAttention,
+ CohereConfig,
+ CohereDecoderLayer,
+ CohereForCausalLM,
+ CohereLayerNorm,
+ CohereMLP,
+ DynamicCache,
+ StaticCache,
+ apply_rotary_pos_emb,
+ logger,
+ repeat_kv,
+)
+
+from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask
+
+
+def gaudi_cohere_attention_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ token_idx: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """
+ Copied from CohereAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/cohere/modeling_cohere.py
+ The only differences are:
+ - add new args token_idx
+ - optimize KV cache
+ """
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+ if self.use_qk_norm:
+ query_states = self.q_norm(query_states)
+ key_states = self.k_norm(key_states)
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ if token_idx is not None:
+ if len(past_key_value.key_cache) <= self.layer_idx:
+ past_key_value.key_cache.append(key_states)
+ past_key_value.value_cache.append(value_states)
+ else:
+ past_key_value.key_cache[self.layer_idx].index_copy_(2, token_idx - 1, key_states)
+ past_key_value.value_cache[self.layer_idx].index_copy_(2, token_idx - 1, value_states)
+ key_states = past_key_value.key_cache[self.layer_idx]
+ value_states = past_key_value.value_cache[self.layer_idx]
+ else:
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class GaudiCohereDecoderLayer(CohereDecoderLayer):
+ def __init__(self, config: CohereConfig, layer_idx: int):
+ super(CohereDecoderLayer, self).__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = CohereAttention(config=config, layer_idx=layer_idx)
+
+ self.mlp = CohereMLP(config)
+ self.input_layernorm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ token_idx: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Copied from CohereDecoderLayer.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/cohere/modeling_cohere.py
+ The only differences are:
+ - add new args token_idx
+ """
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states_attention, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ token_idx=token_idx,
+ )
+
+ # Fully Connected
+ hidden_states_mlp = self.mlp(hidden_states)
+
+ # Add everything together
+ hidden_states = residual + hidden_states_attention + hidden_states_mlp
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+def gaudi_cohere_model_forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ token_idx: Optional[torch.Tensor] = None,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+ """
+ Copied from CohereModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/cohere/modeling_cohere.py
+ The only differences are:
+ - add new args token_idx
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.")
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ past_seen_tokens = 0
+ return_legacy_cache = False
+ if (
+ use_cache and not isinstance(past_key_values, Cache) and not self.training
+ ): # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = True
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+
+ # embed positions
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ token_idx=token_idx,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class GaudiCohereForCausalLM(CohereForCausalLM):
+ """
+ Inherits from CohereForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/cohere/modeling_cohere.py
+ The only differences are:
+ - add new args token_idx
+ - add token_idx into model_inputs
+ - from step2 when enable KV cache, slice next_input_ids from input_ids base on the token_idx
+ - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx
+ """
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ token_idx: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ token_idx=token_idx,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+ logits = logits * self.logit_scale
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ **kwargs,
+ ):
+ token_idx = kwargs.get("token_idx", None)
+ if past_key_values is not None:
+ if token_idx is None:
+ if inputs_embeds is not None: # Exception 1
+ input_ids = input_ids[:, -cache_position.shape[0] :]
+ elif (
+ input_ids.shape[1] != cache_position.shape[0]
+ ): # Default case (the "else", a no op, is Exception 2)
+ input_ids = input_ids[:, cache_position]
+ else:
+ input_ids = torch.index_select(input_ids, 1, token_idx - 1)
+
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ if token_idx is not None:
+ position_ids = torch.index_select(position_ids, 1, token_idx - 1)
+ else:
+ position_ids = position_ids[:, -input_ids.shape[1] :]
+
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format)
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and cache_position[0] == 0:
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
+ else:
+ # The clone here is for the same reason as for `position_ids`.
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
+
+ if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
+ if model_inputs["inputs_embeds"] is not None:
+ batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+ device = model_inputs["inputs_embeds"].device
+ else:
+ batch_size, sequence_length = model_inputs["input_ids"].shape
+ device = model_inputs["input_ids"].device
+
+ dtype = self.lm_head.weight.dtype
+ min_dtype = torch.finfo(dtype).min
+
+ attention_mask = _gaudi_prepare_4d_causal_attention_mask(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=past_key_values.get_max_length(),
+ dtype=dtype,
+ device=device,
+ min_dtype=min_dtype,
+ cache_position=cache_position,
+ batch_size=batch_size,
+ )
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "use_cache": use_cache,
+ "attention_mask": attention_mask,
+ "token_idx": token_idx,
+ }
+ )
+ return model_inputs
diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py
index 60b7899e72..f4c9d454ab 100644
--- a/optimum/habana/transformers/models/falcon/modeling_falcon.py
+++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py
@@ -578,7 +578,7 @@ def attention_all_reduce(self, attn_output):
def post_attn_forward(self, attn_output):
if hasattr(self.dense, "all_reduce"):
- self.dense.post_all_reduce(attn_output)
+ return self.dense.post_all_reduce(attn_output)
return attn_output
@@ -598,7 +598,7 @@ def mlp_all_reduce(self, x):
def post_mlp_forward(self, x):
if hasattr(self.dense_4h_to_h, "all_reduce"):
- self.dense_4h_to_h.post_all_reduce(x)
+ return self.dense_4h_to_h.post_all_reduce(x)
return x
diff --git a/optimum/habana/transformers/models/gemma/modeling_gemma.py b/optimum/habana/transformers/models/gemma/modeling_gemma.py
index 1c270b62f6..67ea4e4a62 100755
--- a/optimum/habana/transformers/models/gemma/modeling_gemma.py
+++ b/optimum/habana/transformers/models/gemma/modeling_gemma.py
@@ -20,6 +20,7 @@
"""PyTorch Gemma model."""
import math
+import os
from typing import List, Optional, Tuple, Union
import torch
@@ -214,7 +215,7 @@ def pre_attn_forward(
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
- cache_idx: int = None,
+ cache_idx: Optional[int] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
@@ -289,7 +290,8 @@ def pre_attn_forward(
if q_len == 1:
# next token
- with ht.sdp_kernel(enable_recompute=False):
+ use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
+ with ht.sdp_kernel(enable_recompute=use_recompute):
attn_output = FusedSDPA.apply(
query_states, key_states, value_states, attention_mask, 0.0, False, None
)
@@ -357,7 +359,7 @@ def attention_all_reduce(self, attn_output):
def post_attn_forward(self, attn_output):
if hasattr(self.o_proj, "post_all_reduce"):
- self.o_proj.post_all_reduce(attn_output)
+ return self.o_proj.post_all_reduce(attn_output)
return attn_output
@@ -407,23 +409,23 @@ def pre_attn(
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
- cache_idx: int = None,
+ cache_idx: Optional[int] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
hidden_states = self.input_layernorm(hidden_states)
hidden_states, attn_weights, present_key_value = self.self_attn.pre_attn_forward(
- hidden_states,
- attention_mask,
- position_ids,
- past_key_value,
- output_attentions,
- use_cache,
- cache_position,
- token_idx,
- attn_softmax_bf16,
- reuse_cache,
- use_flash_attention,
- flash_attention_recompute,
- flash_attention_causal_mask,
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ token_idx=token_idx,
+ attn_softmax_bf16=attn_softmax_bf16,
+ reuse_cache=reuse_cache,
+ use_flash_attention=use_flash_attention,
+ flash_attention_recompute=flash_attention_recompute,
+ flash_attention_causal_mask=flash_attention_causal_mask,
cache_idx=cache_idx,
)
return hidden_states, attn_weights, present_key_value
@@ -443,7 +445,7 @@ def forward(
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
- cache_idx: int = None,
+ cache_idx: Optional[int] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Copied from GemmaDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.38.1/src/transformers/models/gemma/modeling_gemma.py
@@ -453,16 +455,16 @@ def forward(
residual = hidden_states
hidden_states, self_attn_weights, present_key_value = self.pre_attn(
- hidden_states,
- attention_mask,
- position_ids,
- past_key_value,
- output_attentions,
- use_cache,
- cache_position,
- token_idx,
- attn_softmax_bf16,
- reuse_cache,
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ token_idx=token_idx,
+ attn_softmax_bf16=attn_softmax_bf16,
+ reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
@@ -717,6 +719,7 @@ def forward(
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
+ reuse_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
@@ -746,6 +749,7 @@ def forward(
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
+ reuse_cache=reuse_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
@@ -859,9 +863,13 @@ def prepare_inputs_for_generation(
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
+ "reuse_cache": kwargs.get("reuse_cache"),
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
"token_idx": token_idx,
+ "use_flash_attention": kwargs.get("use_flash_attention"),
+ "flash_attention_recompute": kwargs.get("flash_attention_recompute"),
+ "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"),
}
)
return model_inputs
diff --git a/optimum/habana/transformers/models/idefics2/__init__.py b/optimum/habana/transformers/models/idefics2/__init__.py
new file mode 100644
index 0000000000..5c52749e7a
--- /dev/null
+++ b/optimum/habana/transformers/models/idefics2/__init__.py
@@ -0,0 +1,2 @@
+from .image_processing_idefics2 import Gaudi2Idefics2ImageProcessor
+from .modeling_idefics2 import GaudiIdefics2ForConditionalGeneration, GaudiIdefics2Model, GaudiIdefics2VisionEmbeddings
diff --git a/optimum/habana/transformers/models/idefics2/image_processing_idefics2.py b/optimum/habana/transformers/models/idefics2/image_processing_idefics2.py
new file mode 100644
index 0000000000..3bf45a36de
--- /dev/null
+++ b/optimum/habana/transformers/models/idefics2/image_processing_idefics2.py
@@ -0,0 +1,84 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Iterable, List, Optional, Union
+
+import numpy as np
+from transformers.image_processing_utils import BatchFeature
+from transformers.image_utils import ChannelDimension, infer_channel_dimension_format
+from transformers.models.idefics2.image_processing_idefics2 import (
+ Idefics2ImageProcessor,
+ get_max_height_width,
+ make_pixel_mask,
+)
+from transformers.utils import TensorType
+
+
+class Gaudi2Idefics2ImageProcessor(Idefics2ImageProcessor):
+ def pad(
+ self,
+ images: List[np.ndarray],
+ constant_values: Union[float, Iterable[float]] = 0,
+ return_pixel_mask: bool = True,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Optional[ChannelDimension] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> BatchFeature:
+ """
+ Inherits from Idefics2ImageProcessor::pad https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/image_processing_idefics2.py#L314
+ The only differences are:
+ - pad size use longest_edge, so the image size will not change, aims to accelerate finetune speed
+ """
+
+ if getattr(self, "pad_to_longest_edge", False):
+ pad_size = (self.size["longest_edge"], self.size["longest_edge"])
+ else:
+ pad_size = get_max_height_width(images, input_data_format=input_data_format)
+
+ batch_size = len(images)
+ max_num_images = max(len(images_) for images_ in images)
+ input_data_format = (
+ infer_channel_dimension_format(images[0][0]) if input_data_format is None else input_data_format
+ )
+ data_format = input_data_format if data_format is None else data_format
+
+ def empty_image(size, input_data_format):
+ if input_data_format == ChannelDimension.FIRST:
+ return np.zeros((3, *size), dtype=np.uint8)
+ elif input_data_format == ChannelDimension.LAST:
+ return np.zeros((*size, 3), dtype=np.uint8)
+ raise ValueError("Invalid channel dimension format.")
+
+ padded_images_list = [
+ [empty_image(pad_size, data_format) for _ in range(max_num_images)] for _ in range(batch_size)
+ ]
+ padded_masks = [[np.zeros(pad_size) for _ in range(max_num_images)] for _ in range(batch_size)]
+
+ for batch_idx in range(batch_size):
+ for sample_idx, image in enumerate(images[batch_idx]):
+ padded_images_list[batch_idx][sample_idx] = self._pad_image(
+ image,
+ pad_size,
+ constant_values=constant_values,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ padded_masks[batch_idx][sample_idx] = make_pixel_mask(
+ image, output_size=pad_size, input_data_format=input_data_format
+ )
+
+ padded_masks = padded_masks if return_pixel_mask else None
+ return padded_images_list, padded_masks
diff --git a/optimum/habana/transformers/models/idefics2/modeling_idefics2.py b/optimum/habana/transformers/models/idefics2/modeling_idefics2.py
new file mode 100644
index 0000000000..7b92bca9c3
--- /dev/null
+++ b/optimum/habana/transformers/models/idefics2/modeling_idefics2.py
@@ -0,0 +1,510 @@
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Idefics2 model."""
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch.nn import CrossEntropyLoss
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.models.idefics2.modeling_idefics2 import (
+ Idefics2BaseModelOutputWithPast,
+ Idefics2CausalLMOutputWithPast,
+ Idefics2ForConditionalGeneration,
+ Idefics2Model,
+ Idefics2VisionEmbeddings,
+)
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class GaudiIdefics2VisionEmbeddings(Idefics2VisionEmbeddings):
+ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
+ """
+ Inherits from Idefics2VisionEmbeddings::forward https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/modeling_idefics2.py#L159
+ The only differences are:
+ - add int() in nb_patches_h. nb_patches_w to avoid overflow in torch.arange. sometimes return shape is nb_patches_h/nb_patch_w + 1
+ - delete to("cpu") of p_attn_mask
+ """
+
+ batch_size, _, max_im_h, max_im_w = pixel_values.shape
+
+ patch_embeds = self.patch_embedding(pixel_values)
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+ max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
+ boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
+ position_ids = torch.full(
+ size=(batch_size, max_nb_patches_h * max_nb_patches_w),
+ fill_value=0,
+ device=self.position_embedding.weight.device,
+ )
+
+ for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+ nb_patches_h = int(p_attn_mask[:, 0].sum())
+ nb_patches_w = int(p_attn_mask[0].sum())
+
+ fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
+ fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+
+ bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
+ bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
+
+ pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
+ position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids.to(self.position_embedding.weight.device)
+ embeddings = embeddings + self.position_embedding(position_ids)
+ return embeddings
+
+
+class GaudiIdefics2Model(Idefics2Model):
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_attention_mask: Optional[torch.BoolTensor] = None,
+ image_hidden_states: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Idefics2BaseModelOutputWithPast]:
+ """
+ Inherits from Idefics2Model::forward https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/modeling_idefics2.py#L1303
+ The only differences are:
+ - ignoring new Cache path for HPU
+ - unfold is not supported in HPU, replace with conv2d
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if self.training and self.text_model.gradient_checkpointing and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ past_seen_tokens = 0
+ return_legacy_cache = True
+ use_new_cache = False # Ignoring new Cache path for HPU
+ if use_cache and use_new_cache:
+ if not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = True
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ past_seen_tokens = past_key_values.get_seq_length()
+ else:
+ if past_key_values is not None:
+ past_seen_tokens = past_key_values[0][0].shape[2]
+ if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
+ raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
+
+ # START VISUAL INPUTS INTEGRATION
+ if pixel_values is not None and image_hidden_states is not None:
+ raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
+ elif pixel_values is not None:
+ batch_size, num_images, num_channels, height, width = pixel_values.shape
+ pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility
+ pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
+
+ # Remove padding images - padding images are full 0.
+ nb_values_per_image = pixel_values.shape[1:].numel()
+ real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
+ pixel_values = pixel_values[real_images_inds].contiguous()
+
+ # Handle the vision attention mask
+ if pixel_attention_mask is None:
+ pixel_attention_mask = torch.ones(
+ size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
+ dtype=torch.bool,
+ device=pixel_values.device,
+ )
+ else:
+ # Remove padding images from the mask/pP p
+ pixel_attention_mask = pixel_attention_mask.view(
+ batch_size * num_images, *pixel_attention_mask.shape[2:]
+ )
+ pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
+
+ patch_size = self.config.vision_config.patch_size
+ conv_kernel = torch.ones(
+ [1, 1, patch_size, patch_size], dtype=pixel_values.dtype, device=pixel_values.device
+ )
+ patches_subgrid = torch.nn.functional.conv2d(
+ pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype), conv_kernel, stride=patch_size
+ ).squeeze(1)
+ patch_attention_mask = torch.eq(patches_subgrid, (patch_size * patch_size))
+
+ # Get sequence from the vision encoder
+ image_hidden_states = self.vision_model(
+ pixel_values=pixel_values,
+ patch_attention_mask=patch_attention_mask,
+ ).last_hidden_state
+
+ # Modality projection & resampling
+ image_hidden_states = self.connector(
+ image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1)
+ )
+
+ elif image_hidden_states is not None:
+ image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
+
+ if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None:
+ # When we generate, we don't want to replace the potential image_token_id that we generated by images
+ # that simply don't exist
+ inputs_embeds = self.inputs_merger(
+ input_ids=input_ids,
+ inputs_embeds=inputs_embeds,
+ image_hidden_states=image_hidden_states,
+ )
+
+ outputs = self.text_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if return_legacy_cache and use_cache:
+ if return_dict:
+ outputs.past_key_values = (
+ outputs.past_key_values.to_legacy_cache()
+ if isinstance(outputs.past_key_values, Cache)
+ else outputs.past_key_values
+ )
+ else:
+ outputs[1] = outputs[1].to_legacy_cache() if isinstance(outputs[1], Cache) else outputs[1]
+
+ if not return_dict:
+ return tuple(v for v in [*outputs, image_hidden_states] if v is not None)
+
+ return Idefics2BaseModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_hidden_states,
+ )
+
+ def inputs_merger(
+ self,
+ input_ids: torch.LongTensor,
+ inputs_embeds: Optional[torch.Tensor],
+ image_hidden_states: Optional[torch.Tensor],
+ ):
+ """
+ Inherits from Idefics2Model::inputs_merger https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/modeling_idefics2.py#L1268
+ The only differences are:
+ - replace `==` with torch.where to fix the issue in hpu graph
+ """
+ num_images, _, vision_hidden_size = image_hidden_states.shape
+ special_image_token_mask = torch.where(input_ids == self.image_token_id)
+ new_inputs_embeds = inputs_embeds.clone()
+ reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size)
+ new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
+ return new_inputs_embeds
+
+
+class GaudiIdefics2ForConditionalGeneration(Idefics2ForConditionalGeneration):
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_attention_mask: Optional[torch.BoolTensor] = None,
+ image_hidden_states: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ token_idx: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, Idefics2CausalLMOutputWithPast]:
+ """
+ Inherits from Idefics2ForConditionalGeneration::forward https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/modeling_idefics2.py#L1505
+ The only differences are:
+ - add new args token_idx
+ """
+ if token_idx is not None:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ if input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ past_seen_tokens = 0
+ return_legacy_cache = True
+ use_new_cache = False # Ignoring new Cache path for HPU
+ if use_cache and use_new_cache:
+ if not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = True
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ past_seen_tokens = past_key_values.get_seq_length()
+ else:
+ if past_key_values is not None:
+ past_seen_tokens = past_key_values[0][0].shape[2]
+ if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
+ raise ValueError(
+ "When first calling the model, if input_embeds are passed, input_ids should not be None."
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.model.text_model.get_input_embeddings()(input_ids)
+
+ # START VISUAL INPUTS INTEGRATION
+ if pixel_values is not None and image_hidden_states is not None:
+ raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
+ elif image_hidden_states is not None:
+ image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
+
+ if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None:
+ # When we generate, we don't want to replace the potential image_token_id that we generated by images
+ # that simply don't exist
+ inputs_embeds = self.model.inputs_merger(
+ input_ids=input_ids,
+ inputs_embeds=inputs_embeds,
+ image_hidden_states=image_hidden_states,
+ )
+
+ outputs = self.model.text_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ token_idx=token_idx,
+ )
+
+ if return_legacy_cache and use_cache:
+ if return_dict:
+ outputs.past_key_values = (
+ outputs.past_key_values.to_legacy_cache()
+ if isinstance(outputs.past_key_values, Cache)
+ else outputs.past_key_values
+ )
+ else:
+ outputs[1] = outputs[1].to_legacy_cache() if isinstance(outputs[1], Cache) else outputs[1]
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ # Shift so that tokens < n predict n
+ if attention_mask is not None:
+ shift_attention_mask = attention_mask[..., 1:].to(logits.device)
+ shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
+ shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
+ else:
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return Idefics2CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_hidden_states,
+ )
+ else:
+ return super().forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ pixel_values=pixel_values,
+ pixel_attention_mask=pixel_attention_mask,
+ image_hidden_states=image_hidden_states,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+ ):
+ """
+ Inherits from Idefics2ForConditionalGeneration::forward https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/modeling_idefics2.py#L1622
+ The only differences are:
+ - add new args token_idx
+ - add None "Cache" past_key_values support
+ - move vision_model to prepare_input_for_generation
+ """
+ past_length = 0
+ token_idx = kwargs.get("token_idx", None)
+ # Omit tokens covered by past_key_values
+ if past_key_values is not None:
+ # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
+ if isinstance(past_key_values, Cache):
+ past_length = past_key_values.get_seq_length()
+ max_cache_length = past_key_values.get_max_length()
+ else:
+ past_length = past_key_values[0][0].shape[2]
+ max_cache_length = None
+
+ # Keep only the unprocessed tokens:
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+ # input)
+ if token_idx is not None:
+ input_ids = torch.index_select(input_ids, 1, token_idx - 1)
+ elif attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+ # input_ids based on the past_length.
+ elif past_length < input_ids.shape[1]:
+ input_ids = input_ids[:, past_length:]
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+ if (
+ max_cache_length is not None
+ and attention_mask is not None
+ and past_length + input_ids.shape[1] > max_cache_length
+ ):
+ attention_mask = attention_mask[:, -max_cache_length:]
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ if token_idx is not None:
+ position_ids = torch.index_select(position_ids, 1, token_idx - 1)
+ else:
+ position_ids = position_ids[:, -input_ids.shape[1] :]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_length == 0:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ image_hidden_states = kwargs.get("image_hidden_states", None)
+ if image_hidden_states is not None:
+ pixel_values = None
+ pixel_attention_mask = None
+ else:
+ pixel_values = kwargs.get("pixel_values", None)
+ pixel_attention_mask = kwargs.get("pixel_attention_mask", None)
+
+ if token_idx is not None and pixel_values is not None:
+ batch_size, num_images, num_channels, height, width = pixel_values.shape
+ pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility
+ pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
+
+ # Remove padding images - padding images are full 0.
+ nb_values_per_image = pixel_values.shape[1:].numel()
+ real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
+ pixel_values = pixel_values[real_images_inds].contiguous()
+
+ # Handle the vision attention mask
+ if pixel_attention_mask is None:
+ pixel_attention_mask = torch.ones(
+ size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
+ dtype=torch.bool,
+ device=pixel_values.device,
+ )
+ else:
+ # Remove padding images from the mask/pP p
+ pixel_attention_mask = pixel_attention_mask.view(
+ batch_size * num_images, *pixel_attention_mask.shape[2:]
+ )
+ pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
+
+ patch_size = self.config.vision_config.patch_size
+ conv_kernel = torch.ones(
+ [1, 1, patch_size, patch_size], dtype=pixel_values.dtype, device=pixel_values.device
+ )
+ patches_subgrid = torch.nn.functional.conv2d(
+ pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype), conv_kernel, stride=patch_size
+ ).squeeze(1)
+ patch_attention_mask = torch.eq(patches_subgrid, (patch_size * patch_size))
+
+ # Get sequence from the vision encoder
+ image_hidden_states = self.model.vision_model(
+ pixel_values=pixel_values,
+ patch_attention_mask=patch_attention_mask,
+ ).last_hidden_state
+
+ # Modality projection & resampling
+ image_hidden_states = self.model.connector(
+ image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1)
+ )
+ pixel_values = None
+ pixel_attention_mask = None
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ "pixel_values": pixel_values,
+ "pixel_attention_mask": pixel_attention_mask,
+ "image_hidden_states": image_hidden_states,
+ "token_idx": token_idx,
+ }
+ )
+ return model_inputs
diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index ce7d3cc283..9b3bdd6388 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -760,7 +760,7 @@ def attention_all_reduce(self, attn_output):
def post_attn_forward(self, attn_output):
if hasattr(self.o_proj, "post_all_reduce"):
- self.o_proj.post_all_reduce(attn_output)
+ return self.o_proj.post_all_reduce(attn_output)
return attn_output
diff --git a/optimum/habana/transformers/models/mllama/__init__.py b/optimum/habana/transformers/models/mllama/__init__.py
new file mode 100644
index 0000000000..198f1cc2aa
--- /dev/null
+++ b/optimum/habana/transformers/models/mllama/__init__.py
@@ -0,0 +1,10 @@
+from .modeling_mllama import (
+ GaudiMllamaCrossAttentionDecoderLayer,
+ GaudiMllamaForCausalLM,
+ GaudiMllamaForConditionalGeneration,
+ GaudiMllamaSelfAttentionDecoderLayer,
+ GaudiMllamaTextCrossAttention,
+ GaudiMllamaTextModel,
+ GaudiMllamaTextSelfAttention,
+ GaudiMllamaVisionModel,
+)
diff --git a/optimum/habana/transformers/models/mllama/modeling_mllama.py b/optimum/habana/transformers/models/mllama/modeling_mllama.py
new file mode 100644
index 0000000000..e5c7ced0d4
--- /dev/null
+++ b/optimum/habana/transformers/models/mllama/modeling_mllama.py
@@ -0,0 +1,1157 @@
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Mllama model."""
+
+import math
+import os
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from transformers.cache_utils import Cache
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.models.mllama.configuration_mllama import MllamaConfig, MllamaTextConfig
+from transformers.models.mllama.modeling_mllama import (
+ MllamaCrossAttentionDecoderLayer,
+ MllamaForCausalLM,
+ MllamaForConditionalGeneration,
+ MllamaSelfAttentionDecoderLayer,
+ MllamaTextCrossAttention,
+ MllamaTextModel,
+ MllamaTextSelfAttention,
+ MllamaVisionModel,
+ _prepare_4d_causal_attention_mask_with_cache_position,
+ _prepare_aspect_ratio_attention_mask,
+ apply_rotary_pos_emb,
+ repeat_kv,
+)
+from transformers.utils import (
+ logging,
+)
+
+from ...modeling_attn_mask_utils import (
+ _gaudi_prepare_4d_causal_attention_mask,
+)
+
+
+logger = logging.get_logger(__name__)
+
+try:
+ from habana_frameworks.torch.hpex.kernels import FusedSDPA
+except ImportError:
+ print("Not using HPU fused scaled dot-product attention kernel.")
+ FusedSDPA = None
+
+
+class ModuleFusedSDPA(torch.nn.Module):
+ def __init__(self, fusedSDPA):
+ super().__init__()
+ self._hpu_kernel_fsdpa = fusedSDPA
+
+ def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale):
+ return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale)
+
+
+def _prepare_cross_attention_mask(
+ cross_attention_mask: torch.Tensor,
+ num_vision_tokens: int,
+ dtype: str,
+ token_idx: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Copied from _prepare_cross_attention_mask: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L99
+ The only differences are:
+ - if there's pading in cross_attention_mask in the right. do not masked it, or else it will impact softmax in crossattention
+ """
+ # reshape so it can be used by attn module
+ batch_size, text_total_length, *_ = cross_attention_mask.shape
+ cross_attention_mask = cross_attention_mask.repeat_interleave(num_vision_tokens, dim=3)
+ cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1)
+ cross_attention_mask = cross_attention_mask.unsqueeze(1)
+
+ # invert the mask
+ inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)
+ cross_attention_mask = inverted_cross_attn_mask.masked_fill(
+ inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min
+ )
+
+ # apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's
+ # last dimension contains negative infinity values, otherwise it's 1
+ negative_inf_value = torch.finfo(dtype).min
+ full_text_row_masked_out_mask = (
+ (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None]
+ )
+ if token_idx is not None:
+ full_text_row_masked_out_mask2 = full_text_row_masked_out_mask.clone()
+ full_text_row_masked_out_mask2[:, :, token_idx:, :] = 1
+ cross_attention_mask *= full_text_row_masked_out_mask2
+ else:
+ cross_attention_mask *= full_text_row_masked_out_mask
+
+ return cross_attention_mask, full_text_row_masked_out_mask
+
+
+class GaudiMllamaTextCrossAttention(MllamaTextCrossAttention):
+ def __init__(self, config: Optional[MllamaTextConfig] = None, layer_idx: Optional[int] = None):
+ super().__init__(config, layer_idx)
+ self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cross_attention_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ use_cache: bool = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ token_idx: Optional[torch.Tensor] = None,
+ use_flash_attention: Optional[bool] = False,
+ flash_attention_recompute: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """
+ Copied from MllamaTextCrossAttention::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L512
+ The only differences are:
+ - add token_idx support
+ - add support if past_key_value is not Cache
+ - cache position is None
+ - add use_flash_attention and flash_attention_recompute
+ """
+ """Input shape: Batch x Time x Channel"""
+ bsz, q_len, _ = hidden_states.size()
+ query_states = self.q_proj(hidden_states)
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ query_states = self.q_norm(query_states)
+
+ if cross_attention_states is not None:
+ key_states = self.k_proj(cross_attention_states)
+ value_states = self.v_proj(cross_attention_states)
+ key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ if not (FusedSDPA and use_flash_attention):
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ key_states = self.k_norm(key_states)
+ if past_key_value is not None:
+ # if we have a new image + new tokens, we only computed key_states on that new image
+ # we still update the cross key states, past_image, new_image. And use it!
+ if isinstance(past_key_value, Cache):
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+ else:
+ if token_idx is not None:
+ past_key_value[0].index_copy_(2, token_idx - 1, key_states)
+ past_key_value[1].index_copy_(2, token_idx - 1, value_states)
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ else:
+ key_states = torch.cat((past_key_value[0], key_states), dim=2)
+ value_states = torch.cat((past_key_value[1], value_states), dim=2)
+ if use_cache and not isinstance(past_key_value, Cache):
+ past_key_value = [key_states, value_states]
+ elif not isinstance(past_key_value, Cache) and past_key_value is not None:
+ key_states, value_states = (past_key_value[0], past_key_value[1])
+ elif cache_position is not None and cache_position[0] != 0:
+ key_states, value_states = (
+ past_key_value.key_cache[self.layer_idx],
+ past_key_value.value_cache[self.layer_idx],
+ )
+ else:
+ raise ValueError(
+ "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
+ )
+
+ if FusedSDPA and use_flash_attention:
+ import habana_frameworks.torch.hpu as ht
+
+ if q_len == 1:
+ # next token
+ use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
+ with ht.sdp_kernel(enable_recompute=use_recompute):
+ attn_output = self.fused_scaled_dot_product_attention(
+ query_states, key_states, value_states, attention_mask, 0.0, False, None
+ )
+ else:
+ with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
+ attn_output = self.fused_scaled_dot_product_attention(
+ query_states, key_states, value_states, attention_mask, 0.0, False, None
+ )
+ else:
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class GaudiMllamaTextSelfAttention(MllamaTextSelfAttention):
+ def __init__(self, config: Optional[MllamaTextConfig] = None, layer_idx: Optional[int] = None):
+ super().__init__(config, layer_idx)
+ self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ position_embeddings: torch.Tensor,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ past_key_value=None,
+ cache_position=None,
+ token_idx: Optional[torch.Tensor] = None,
+ use_flash_attention: Optional[bool] = False,
+ flash_attention_recompute: Optional[bool] = False,
+ **kwargs,
+ ):
+ """
+ Copied from MllamaTextSelfAttention::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L733
+ The only differences are:
+ - add token_idx support
+ - add support if past_key_value is not Cache
+ - add use_flash_attention and flash_attention_recompute
+ """
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ if isinstance(past_key_value, Cache):
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, cache_kwargs
+ )
+ else:
+ if token_idx is not None:
+ past_key_value[0].index_copy_(2, token_idx - 1, key_states)
+ past_key_value[1].index_copy_(2, token_idx - 1, value_states)
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ else:
+ key_states = torch.cat((past_key_value[0], key_states), dim=2)
+ value_states = torch.cat((past_key_value[1], value_states), dim=2)
+ if use_cache and not isinstance(past_key_value, Cache):
+ past_key_value = [key_states, value_states]
+
+ if FusedSDPA and use_flash_attention:
+ import habana_frameworks.torch.hpu as ht
+
+ if q_len == 1:
+ # next token
+ use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
+ with ht.sdp_kernel(enable_recompute=use_recompute):
+ attn_output = self.fused_scaled_dot_product_attention(
+ query_states, key_states, value_states, attention_mask, 0.0, False, None
+ )
+ else:
+ with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
+ attn_output = self.fused_scaled_dot_product_attention(
+ query_states, key_states, value_states, attention_mask, 0.0, False, None
+ )
+ else:
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, -1)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+# Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer
+class GaudiMllamaSelfAttentionDecoderLayer(MllamaSelfAttentionDecoderLayer):
+ def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None:
+ super(GaudiMllamaSelfAttentionDecoderLayer, self).__init__(config, layer_idx)
+ self.self_attn = GaudiMllamaTextSelfAttention(config, layer_idx=layer_idx)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cross_attention_states: Optional[torch.Tensor] = None,
+ cross_attention_mask: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ token_idx: Optional[torch.Tensor] = None,
+ use_flash_attention: Optional[bool] = False,
+ flash_attention_recompute: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Copied from MllamaSelfAttentionDecoderLayer::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L904
+ The only differences are:
+ - add token_idx input
+ - add use_flash_attention and flash_attention_recompute
+ """
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ token_idx=token_idx,
+ use_flash_attention=use_flash_attention,
+ flash_attention_recompute=flash_attention_recompute,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+class GaudiMllamaCrossAttentionDecoderLayer(MllamaCrossAttentionDecoderLayer):
+ def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None:
+ super(GaudiMllamaCrossAttentionDecoderLayer, self).__init__(config, layer_idx)
+ self.cross_attn = GaudiMllamaTextCrossAttention(config, layer_idx=layer_idx)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cross_attention_states: torch.Tensor,
+ cross_attention_mask: torch.Tensor,
+ attention_mask: torch.Tensor,
+ full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor],
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[torch.Tensor] = None,
+ token_idx: Optional[torch.Tensor] = None,
+ use_flash_attention: Optional[bool] = False,
+ flash_attention_recompute: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
+ """
+ Copied from MllamaCrossAttentionDecoderLayer::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L989
+ The only differences are:
+ - add token_idx support
+ - pass use_cache to cross_attn
+ - add use_flash_attention and flash_attention_recompute
+ """
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+
+ hidden_states, attn_weights, past_key_value = self.cross_attn(
+ hidden_states=hidden_states,
+ attention_mask=cross_attention_mask,
+ cross_attention_states=cross_attention_states,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ use_cache=use_cache,
+ token_idx=token_idx,
+ use_flash_attention=use_flash_attention,
+ flash_attention_recompute=flash_attention_recompute,
+ )
+ hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
+
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ if full_text_row_masked_out_mask is not None:
+ hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states # type: ignore
+ hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ if use_cache:
+ outputs += (past_key_value,)
+
+ return outputs
+
+
+class GaudiMllamaTextModel(MllamaTextModel):
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ cross_attention_states: Optional[torch.FloatTensor] = None,
+ cross_attention_mask: Optional[torch.Tensor] = None,
+ full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ token_idx: Optional[torch.Tensor] = None,
+ use_flash_attention: Optional[bool] = False,
+ flash_attention_recompute: Optional[bool] = False,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ """
+ Copied from MllamaTextModel::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L1617
+ The only differences are:
+ - add token_idx support
+ - add support if past_key_value is not Cache
+ - add use_flash_attention and flash_attention_recompute
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ hidden_states = inputs_embeds
+ if isinstance(past_key_values, Cache):
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ else:
+ past_seen_tokens = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+ ignore_cache_position = True # Ignoring cache position for HPU, or else hpu graph may has issue
+ if ignore_cache_position is False:
+ if cache_position is None:
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+ else:
+ if position_ids is None:
+ position_ids = torch.arange(
+ past_seen_tokens,
+ inputs_embeds.shape[1] + past_seen_tokens,
+ dtype=torch.long,
+ device=inputs_embeds.device,
+ )
+ position_ids = position_ids.unsqueeze(0)
+ cache_position = None
+ causal_mask = _gaudi_prepare_4d_causal_attention_mask(
+ attention_mask,
+ input_ids.shape,
+ inputs_embeds,
+ past_seen_tokens,
+ )
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None if isinstance(past_key_values, Cache) else ()
+
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ # For text-only path we should skip cross attention layers.
+ # Let's check if the layer is cross attention layer and if we have cross attention states
+ # or cached cross attention states.
+ is_cross_attention_layer = idx in self.cross_attention_layers
+ is_cross_attention_cache_empty = past_key_values is None or (
+ past_key_values is not None and past_key_values.get_seq_length(idx) == 0
+ if isinstance(past_key_values, Cache)
+ else False
+ )
+
+ if is_cross_attention_layer and cross_attention_states is None and is_cross_attention_cache_empty:
+ continue
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ cross_attention_states,
+ cross_attention_mask,
+ causal_mask,
+ full_text_row_masked_out_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ position_embeddings,
+ )
+ else:
+ if isinstance(past_key_values, Cache):
+ past_key_value = past_key_values
+ else:
+ past_key_value = None if past_key_values is None else past_key_values[idx]
+ layer_outputs = decoder_layer(
+ hidden_states,
+ cross_attention_states=cross_attention_states,
+ cross_attention_mask=cross_attention_mask,
+ attention_mask=causal_mask,
+ full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ token_idx=token_idx,
+ use_flash_attention=use_flash_attention,
+ flash_attention_recompute=flash_attention_recompute,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ if isinstance(past_key_values, Cache):
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+ else:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ def _update_causal_mask(
+ self,
+ attention_mask: torch.Tensor,
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool,
+ ):
+ """
+ Copied from MllamaTextModel::_update_causal_mask: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L1768
+ The only differences are:
+ - add support if past_key_value is not Cache
+ """
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and 0.0 in attention_mask:
+ return attention_mask
+ return None
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ if isinstance(past_key_values, Cache):
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ else:
+ past_seen_tokens = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ # TODO: we have only SDPA currently and there's a bug when attn-bias is passed. Need to add eager attn and return the line
+ # self.config._attn_implementation == "sdpa" and
+ if self.config._attn_implementation == "sdpa" and not output_attentions:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype, device = input_tensor.dtype, input_tensor.device
+ min_dtype = torch.finfo(dtype).min
+ sequence_length = input_tensor.shape[1]
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ device=device,
+ min_dtype=min_dtype,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type == "cuda"
+ and not output_attentions
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+
+class GaudiMllamaForCausalLM(MllamaForCausalLM):
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ cross_attention_states: Optional[torch.LongTensor] = None,
+ cross_attention_mask: Optional[torch.LongTensor] = None,
+ full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ num_logits_to_keep: int = 0,
+ token_idx: Optional[torch.Tensor] = None,
+ use_flash_attention: Optional[bool] = False,
+ flash_attention_recompute: Optional[bool] = False,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ """
+ Copied from MllamaForCausalLM::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L1871
+ The only differences are:
+ - add token_idx input
+ - add logits handle if token_idx is not None
+ - add use_flash_attention and flash_attention_recompute
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ cross_attention_states=cross_attention_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ cross_attention_mask=cross_attention_mask,
+ full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ token_idx=token_idx,
+ use_flash_attention=use_flash_attention,
+ flash_attention_recompute=flash_attention_recompute,
+ )
+
+ hidden_states = outputs[0]
+
+ if token_idx is None and num_logits_to_keep != 0:
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
+ else:
+ logits = self.lm_head(hidden_states).float()
+
+ loss = None
+ if labels is not None:
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
+ logits = logits.float()
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class GaudiMllamaForConditionalGeneration(MllamaForConditionalGeneration):
+ def __init__(self, config: MllamaConfig):
+ # sdpa is better for vision model in HPU
+ config._attn_implementation = "sdpa"
+ super(GaudiMllamaForConditionalGeneration, self).__init__(config)
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ aspect_ratio_mask: Optional[torch.Tensor] = None,
+ aspect_ratio_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_states: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ num_logits_to_keep: int = 0,
+ token_idx: Optional[torch.Tensor] = None,
+ use_flash_attention: Optional[bool] = False,
+ flash_attention_recompute: Optional[bool] = False,
+ **kwargs,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ """
+ Copied from MllamaForConditionalGeneration::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2077
+ The only differences are:
+ - add token_idx input
+ - add use_flash_attention and flash_attention_recompute
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if pixel_values is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if pixel_values is not None and cross_attention_states is not None:
+ raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously")
+
+ if pixel_values is not None:
+ if aspect_ratio_ids is None:
+ raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided")
+ # get vision tokens from vision model
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ aspect_ratio_ids=aspect_ratio_ids,
+ aspect_ratio_mask=aspect_ratio_mask,
+ output_hidden_states=output_hidden_states,
+ output_attentions=output_attentions,
+ return_dict=return_dict,
+ )
+ cross_attention_states = vision_outputs[0]
+ cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape(
+ -1, cross_attention_states.shape[-2], self.hidden_size
+ )
+
+ if cross_attention_mask is not None:
+ cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask(
+ cross_attention_mask,
+ num_vision_tokens=self.vision_model.num_patches,
+ dtype=self.dtype,
+ token_idx=token_idx,
+ )
+ else:
+ full_text_row_masked_out_mask = None
+
+ if cross_attention_mask is not None:
+ if cache_position is not None:
+ cross_attention_mask = cross_attention_mask[:, :, cache_position]
+ full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]
+ elif past_key_values is not None:
+ if token_idx is not None:
+ cross_attention_mask = torch.index_select(cross_attention_mask, -2, token_idx - 1)
+ full_text_row_masked_out_mask = torch.index_select(
+ full_text_row_masked_out_mask, -2, token_idx - 1
+ )
+ else:
+ cross_attention_mask = cross_attention_mask[:, :, -1:]
+ full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, -1:]
+ outputs = self.language_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ cross_attention_states=cross_attention_states,
+ cross_attention_mask=cross_attention_mask,
+ full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ output_hidden_states=output_hidden_states,
+ output_attentions=output_attentions,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ num_logits_to_keep=num_logits_to_keep,
+ token_idx=token_idx,
+ use_flash_attention=use_flash_attention,
+ flash_attention_recompute=flash_attention_recompute,
+ )
+
+ return outputs
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids=None,
+ inputs_embeds=None,
+ attention_mask=None,
+ position_ids=None,
+ pixel_values=None,
+ aspect_ratio_ids=None,
+ aspect_ratio_mask=None,
+ cross_attention_mask=None,
+ past_key_values=None,
+ use_cache=False,
+ cache_position=None,
+ num_logits_to_keep=None,
+ **kwargs,
+ ):
+ """
+ Copied from MllamaForConditionalGeneration::prepare_inputs_for_generation: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2208
+ The only differences are:
+ - add token_idx handling
+ - add bucket_internal handling
+ - add use_flash_attention and flash_attention_recompute
+ """
+ token_idx = kwargs.get("token_idx", None)
+ bucket_internal = kwargs.get("bucket_internal", None)
+ if past_key_values is not None:
+ if token_idx is not None:
+ input_ids = torch.index_select(input_ids, 1, token_idx - 1)
+ elif inputs_embeds is not None: # Exception 1
+ input_ids = input_ids[:, -cache_position.shape[0] :]
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
+ input_ids = input_ids[:, cache_position]
+ elif bucket_internal and token_idx is not None:
+ # for the 1st token we can slice the inputs till token idx for the fwd pass.
+ input_ids = input_ids[:, :token_idx]
+ attention_mask = attention_mask[:, :token_idx]
+ if cross_attention_mask is not None:
+ cross_attention_mask = cross_attention_mask[:, :token_idx, ...]
+
+ # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ if token_idx is not None:
+ position_ids = torch.index_select(position_ids, 1, token_idx - 1)
+ else:
+ position_ids = position_ids[:, -input_ids.shape[1] :]
+
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format)
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and cache_position[0] == 0:
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
+ else:
+ # The clone here is for the same reason as for `position_ids`.
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
+
+ if num_logits_to_keep is not None:
+ model_inputs["num_logits_to_keep"] = num_logits_to_keep
+
+ # keep cache_position implementation as None for HPU
+ cache_position = None
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "use_cache": use_cache,
+ "attention_mask": attention_mask,
+ "cross_attention_mask": cross_attention_mask,
+ "token_idx": token_idx,
+ "use_flash_attention": kwargs.get("use_flash_attention"),
+ "flash_attention_recompute": kwargs.get("flash_attention_recompute"),
+ }
+ )
+
+ # If we're in pre-fill or cacheless decoding step, then we need pixel_values and aspect ratios
+ # to compute image hidden states, otherwise they are cached within each cross attn layer
+ if (input_ids == self.config.image_token_index).any():
+ model_inputs["pixel_values"] = pixel_values
+ model_inputs["aspect_ratio_ids"] = aspect_ratio_ids
+ model_inputs["aspect_ratio_mask"] = aspect_ratio_mask
+
+ return model_inputs
+
+ def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
+ """
+ Copied from MllamaForConditionalGeneration::_update_model_kwargs_for_generation: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2274
+ The only differences are:
+ - add token_idx handling
+ """
+ cross_attention_mask_prev = model_kwargs.get("cross_attention_mask", None)
+ model_kwargs = super(MllamaForConditionalGeneration, self)._update_model_kwargs_for_generation(
+ outputs=outputs,
+ model_kwargs=model_kwargs,
+ is_encoder_decoder=is_encoder_decoder,
+ **kwargs,
+ )
+
+ # add cross-attn mask for new token
+ if cross_attention_mask_prev is not None:
+ token_idx = model_kwargs.get("token_idx", None)
+ if token_idx is not None:
+ mask = cross_attention_mask_prev[:, token_idx - 2 : token_idx - 1, ...]
+ cross_attention_mask_prev.index_copy_(1, token_idx - 1, mask)
+ model_kwargs["cross_attention_mask"] = cross_attention_mask_prev
+ else:
+ model_kwargs["cross_attention_mask"] = torch.cat(
+ [cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], dim=1
+ )
+ return model_kwargs
+
+
+class GaudiMllamaVisionModel(MllamaVisionModel):
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ aspect_ratio_ids: torch.Tensor,
+ aspect_ratio_mask: torch.Tensor,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:
+ """
+ Copied from MllamaVisionModel::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L1425
+ The only differences are:
+ - optimize perf of stage "Collect intermediate layer outputs from encoder output"
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape
+
+ pixel_values = pixel_values.reshape(batch_size * num_concurrent_media * num_tiles, num_channels, height, width)
+ aspect_ratio_ids = aspect_ratio_ids.reshape(batch_size * num_concurrent_media, -1)
+
+ # Patch embedding
+ patch_embeds = self.patch_embedding(pixel_values.to(self.dtype).to(self.device))
+ hidden_state = patch_embeds.flatten(2).transpose(1, 2)
+
+ # Tile embeddings
+ _, num_patches, dim = hidden_state.shape
+ hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, num_tiles, -1, dim)
+ hidden_state = self.pre_tile_positional_embedding(hidden_state, aspect_ratio_ids)
+
+ # Add cls token
+ hidden_state = hidden_state.reshape(batch_size * num_concurrent_media * num_tiles, num_patches, dim)
+ hidden_state = self.apply_class_embedding(hidden_state)
+ num_patches += 1
+
+ # Position embeddings
+ hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, num_tiles, num_patches, dim)
+ hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)
+
+ hidden_state = self.layernorm_pre(hidden_state)
+
+ # Compute the number of tokens to pad
+ num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
+ # Compute padding tuple for pad function
+ padding = (0, 0, 0, num_padding_patches) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
+ # Pad the tensor
+ hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
+ slice_index = -num_padding_patches if num_padding_patches > 0 else None
+
+ # Prepare attention mask
+ attention_mask = aspect_ratio_mask.reshape(batch_size * num_concurrent_media, -1)
+ attention_mask = _prepare_aspect_ratio_attention_mask(
+ aspect_ratio_mask=attention_mask,
+ num_patches=self.num_patches,
+ target_length=hidden_state.shape[2],
+ dtype=self.dtype,
+ )
+
+ # Apply encoder
+ hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim)
+ output = self.transformer(
+ hidden_state,
+ attention_mask=attention_mask,
+ output_hidden_states=True,
+ output_attentions=output_attentions,
+ )
+ hidden_state = output[0]
+
+ hidden_state = self.layernorm_post(hidden_state)
+
+ # Apply global encoder
+ hidden_state = hidden_state.reshape(
+ batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, dim
+ )
+ hidden_state = self.post_tile_positional_embedding(hidden_state, aspect_ratio_ids)
+ hidden_state = hidden_state.reshape(
+ batch_size * num_concurrent_media, num_tiles * (num_patches + num_padding_patches), dim
+ )
+ global_output = self.global_transformer(
+ hidden_state,
+ attention_mask=attention_mask,
+ output_hidden_states=output_hidden_states,
+ output_attentions=output_attentions,
+ )
+ hidden_state = global_output[0]
+
+ # Remove padding form hidden state
+ hidden_state = hidden_state.reshape(
+ batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, dim
+ )
+ hidden_state = hidden_state[:, :, :slice_index]
+ hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, num_tiles, num_patches, dim)
+
+ # Collect intermediate layer outputs from encoder output
+ all_intermediate_hidden_states = output[1]
+ intermediate_hidden_states = [
+ hidden_state
+ for idx, hidden_state in enumerate(all_intermediate_hidden_states)
+ if idx in self.intermediate_layers_indices
+ ]
+ intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1)
+
+ """
+ intermediate_hidden_states = torch.stack(all_intermediate_hidden_states, dim=-1)
+ intermediate_hidden_states = intermediate_hidden_states[..., self.intermediate_layers_indices]
+ """
+
+ # Remove padding from intermediate hidden states
+ intermediate_hidden_states = intermediate_hidden_states.reshape(
+ batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, -1
+ )
+ intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]
+ intermediate_hidden_states = intermediate_hidden_states.reshape(
+ batch_size, num_concurrent_media, num_tiles, num_patches, -1
+ )
+
+ # Concatenate final hidden state and intermediate hidden states
+ hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)
+
+ if output_hidden_states:
+ hidden_states = tuple(all_intermediate_hidden_states) + tuple(global_output[1])
+ else:
+ hidden_states = None
+
+ if output_attentions:
+ # global transformer in contrast to `self.transformer` doesn't always return hidden states so we might go index out-of-range
+ global_attn = tuple(global_output[2]) if output_hidden_states else tuple(global_output[1])
+ attentions = tuple(output[2]) + global_attn
+ else:
+ attentions = None
+
+ if not return_dict:
+ return tuple(v for v in [hidden_state, hidden_states, attentions] if v is not None)
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_state,
+ hidden_states=hidden_states,
+ attentions=attentions,
+ )
diff --git a/optimum/habana/transformers/models/modeling_all_models.py b/optimum/habana/transformers/models/modeling_all_models.py
index a769220242..90aa2d5e0f 100644
--- a/optimum/habana/transformers/models/modeling_all_models.py
+++ b/optimum/habana/transformers/models/modeling_all_models.py
@@ -164,7 +164,5 @@ def all_reduce(self, input):
dist.inference_all_reduce(input, group=self.mp_group)
def post_all_reduce(self, input):
- # inplace addition needed for correct results
- if self.bias is not None:
- input += self.bias
- return input
+ output = input + self.bias if (self.bias is not None) else input
+ return output
diff --git a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py
index 7bd8ebcd9b..1484224695 100644
--- a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py
+++ b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py
@@ -198,9 +198,22 @@ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
self.block_size = 4096
self.rotary_emb = GaudiRotaryEmbedding(config=self.config)
+ def get_k_proj_weight(self):
+ """4bit quantization in GPTQ replaces the k_proj.weight with qweight."""
+ if hasattr(self.k_proj, "qweight"):
+ return self.k_proj.qweight
+ return self.k_proj.weight
+
+ def get_k_proj_weight_dtype(self):
+ """4bit quantization in GPTQ replaces the k_proj.weight with qweight.
+ Scales tensor gets the weight dtype."""
+ if hasattr(self.k_proj, "qweight"):
+ return self.k_proj.scales.dtype
+ return self.k_proj.weight.dtype
+
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim)
- device = self.k_proj.weight.device
+ device = self.get_k_proj_weight().device
dtype = self.config.torch_dtype
self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape)
self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape)
@@ -211,7 +224,7 @@ def update_sincos_cache(self, seq_len):
# reduce memory consumption and improve performance.
if seq_len > self.max_position_embeddings:
self.max_position_embeddings = seq_len
- _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len)
+ _, _ = self.rotary_emb(self.get_k_proj_weight(), seq_len=seq_len)
def reorder(self, tensor, beam_idx, dim_a, dim_b):
updated = tensor.index_select(0, beam_idx)
@@ -316,9 +329,11 @@ def pre_attn_forward(
past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape())
else:
if past_key_value is None:
- past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device)
+ past_key = torch.zeros(
+ key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device
+ )
past_value = torch.zeros(
- key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device
+ key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device
)
past_key_value = [past_key, past_value]
key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)
@@ -419,7 +434,7 @@ def attention_all_reduce(self, attn_output):
def post_attn_forward(self, attn_output):
if hasattr(self.o_proj, "post_all_reduce"):
- self.o_proj.post_all_reduce(attn_output)
+ return self.o_proj.post_all_reduce(attn_output)
return attn_output
diff --git a/optimum/habana/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/optimum/habana/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index fa1d8aae53..5b9da828cd 100755
--- a/optimum/habana/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/optimum/habana/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -491,7 +491,7 @@ def attention_all_reduce(self, attn_output):
def post_attn_forward(self, attn_output):
if hasattr(self.o_proj, "post_all_reduce"):
- self.o_proj.post_all_reduce(attn_output)
+ return self.o_proj.post_all_reduce(attn_output)
return attn_output
diff --git a/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py b/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py
index b01a176368..00d9de7193 100644
--- a/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py
@@ -17,6 +17,7 @@
###############################################################################
import math
+import os
from typing import List, Optional, Tuple, Union
import torch
@@ -307,7 +308,8 @@ def pre_attn_forward(
if q_len == 1:
# next token
- with ht.sdp_kernel(enable_recompute=False):
+ use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
+ with ht.sdp_kernel(enable_recompute=use_recompute):
attn_output = FusedSDPA.apply(
query_states, key_states, value_states, attention_mask, 0.0, False, None
)
@@ -374,7 +376,7 @@ def attention_all_reduce(self, attn_output):
def post_attn_forward(self, attn_output):
if hasattr(self.o_proj, "post_all_reduce"):
- self.o_proj.post_all_reduce(attn_output)
+ return self.o_proj.post_all_reduce(attn_output)
return attn_output
diff --git a/optimum/habana/transformers/models/xglm/__init__.py b/optimum/habana/transformers/models/xglm/__init__.py
new file mode 100644
index 0000000000..427d1b7a63
--- /dev/null
+++ b/optimum/habana/transformers/models/xglm/__init__.py
@@ -0,0 +1,6 @@
+from .modeling_xglm import (
+ GaudiXGLMForCausalLM,
+ gaudi_xglm_attention_forward,
+ gaudi_xglm_decoder_layer_forward,
+ gaudi_xglm_model_forward,
+)
diff --git a/optimum/habana/transformers/models/xglm/modeling_xglm.py b/optimum/habana/transformers/models/xglm/modeling_xglm.py
new file mode 100644
index 0000000000..ef5a16801a
--- /dev/null
+++ b/optimum/habana/transformers/models/xglm/modeling_xglm.py
@@ -0,0 +1,515 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
+from transformers.models.xglm.modeling_xglm import XGLMForCausalLM
+from transformers.utils import logging
+
+from ...modeling_attn_mask_utils import (
+ _gaudi_prepare_4d_causal_attention_mask,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+def gaudi_xglm_attention_forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ token_idx: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """
+ Copied from XGLMAttention.forward: https://github.com/huggingface/transformers/blob/v4.44.1/src/transformers/models/xglm/modeling_xglm.py
+ The only differences are:
+ - add new args token_idx
+ - optimize KV cache
+ """
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scaling
+ # get key, value proj
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ if token_idx is not None:
+ past_key_value[0].index_copy_(2, token_idx - 1, key_states)
+ past_key_value[1].index_copy_(2, token_idx - 1, value_states)
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ else:
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = torch.max(
+ attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
+ if attn_weights.dtype == torch.float16:
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
+ else:
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if layer_head_mask is not None:
+ if layer_head_mask.size() != (self.num_heads,):
+ raise ValueError(
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
+ )
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights have to be reshaped
+ # twice and have to be reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+ # partitioned aross GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped, past_key_value
+
+
+def gaudi_xglm_decoder_layer_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = True,
+ token_idx: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ """
+ Copied from XGLMDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.44.1/src/transformers/models/xglm/modeling_xglm.py
+ The only differences are:
+ - add new args token_idx
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ past_key_value=self_attn_past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ token_idx=token_idx,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ # Cross-Attention Block
+ cross_attn_present_key_value = None
+ cross_attn_weights = None
+ if encoder_hidden_states is not None:
+ residual = hidden_states
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+ hidden_states=hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ layer_head_mask=cross_attn_layer_head_mask,
+ past_key_value=cross_attn_past_key_value,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ # add cross-attn to positions 3,4 of present_key_value tuple
+ present_key_value = present_key_value + cross_attn_present_key_value
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights, cross_attn_weights)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+def gaudi_xglm_model_forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ token_idx: Optional[torch.Tensor] = None,
+) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+ """
+ Copied from XGLMModel.forward: https://github.com/huggingface/transformers/blob/v4.44.1/src/transformers/models/xglm/modeling_xglm.py
+ The only differences are:
+ - add new args token_idx
+ - replace _prepare_4d_causal_attention_mask with _gaudi_prepare_4d_causal_attention_mask
+ """
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if position_ids is None:
+ position_ids = torch.arange(
+ past_key_values_length,
+ input_shape[-1] + past_key_values_length,
+ dtype=torch.long,
+ device=input_ids.device if input_ids is not None else inputs_embeds.device,
+ )
+ position_ids = position_ids.unsqueeze(0)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ attention_mask = _gaudi_prepare_4d_causal_attention_mask(
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
+ )
+
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _gaudi_prepare_4d_causal_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ hidden_states = inputs_embeds + self.embed_positions(position_ids, past_key_values_length)
+ hidden_states = nn.functional.dropout(hidden_states, p=float(self.dropout), training=self.training)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache =" " False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+ next_decoder_cache = () if use_cache else None
+
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+ for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+ if attn_mask is not None:
+ if attn_mask.size()[0] != len(self.layers):
+ raise ValueError(
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
+ for idx, decoder_layer in enumerate(self.layers):
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ if self.training:
+ dropout_probability = torch.rand([])
+ if dropout_probability < self.layerdrop:
+ continue
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ head_mask[idx] if head_mask is not None else None,
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+ None,
+ output_attentions,
+ use_cache,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ token_idx=token_idx,
+ )
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class GaudiXGLMForCausalLM(XGLMForCausalLM):
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ token_idx: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
+ """
+ Inherits from XGLMForCausalLM: https://github.com/huggingface/transformers/blob/v4.44.1/src/transformers/models/xglm/modeling_xglm.py
+ The only differences are:
+ - add new args token_idx
+ """
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ head_mask=head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ token_idx=token_idx,
+ )
+
+ logits = self.lm_head(outputs[0])
+
+ loss = None
+ if labels is not None:
+ # shift labels and add a pad token to the end
+ shift_labels = labels.new_zeros(labels.shape)
+ shift_labels[:, :-1] = labels[:, 1:].clone()
+ shift_labels[:, -1] = self.config.pad_token_id
+
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
+ ):
+ """
+ Inherits from XGLMForCausalLM: https://github.com/huggingface/transformers/blob/v4.44.1/src/transformers/models/xglm/modeling_xglm.py
+ The only differences are:
+ - add new args token_idx
+ - add token_idx into model_inputs
+ - from step2 when enable KV cache, slice next_input_ids from input_ids base on the token_idx
+ - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx
+ """
+ token_idx = kwargs.get("token_idx", None)
+ if past_key_values is not None:
+ if token_idx is None:
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
+ else:
+ input_ids = torch.index_select(input_ids, 1, token_idx - 1)
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ if token_idx is not None:
+ position_ids = torch.index_select(position_ids, 1, token_idx - 1)
+ else:
+ position_ids = position_ids[:, -input_ids.shape[1] :]
+ else:
+ position_ids = None
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_ids.shape)
+ # first step, decoder_cached_states are empty
+ return {
+ "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
+ "attention_mask": attention_mask,
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": use_cache,
+ "token_idx": token_idx,
+ }
diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py
index 843f646b14..27eab623bb 100644
--- a/optimum/habana/transformers/trainer.py
+++ b/optimum/habana/transformers/trainer.py
@@ -252,7 +252,7 @@ def __init__(
"The argument `--bf16` was not given but `use_torch_autocast` is True in the Gaudi configuration so mixed-precision training with Torch Autocast is enabled."
)
- if self.use_hpu_amp and "LOWER_LIST" not in os.environ:
+ if self.use_hpu_amp and "PT_HPU_AUTOCAST_LOWER_PRECISION_OPS_LIST" not in os.environ:
self.gaudi_config.declare_autocast_bf16_fp32_ops()
if self.args.use_lazy_mode:
diff --git a/setup.py b/setup.py
index ead947d01c..4249e21924 100644
--- a/setup.py
+++ b/setup.py
@@ -35,7 +35,7 @@
"accelerate >= 0.33.0, < 0.34.0",
"diffusers == 0.29.2",
"huggingface_hub >= 0.24.7",
- "sentence-transformers[train] == 3.0.1",
+ "sentence-transformers == 3.2.1",
]
TESTS_REQUIRE = [
diff --git a/tests/baselines/Llama_3_2_11B_Vision_Instruct.json b/tests/baselines/Llama_3_2_11B_Vision_Instruct.json
new file mode 100644
index 0000000000..3789c63fa9
--- /dev/null
+++ b/tests/baselines/Llama_3_2_11B_Vision_Instruct.json
@@ -0,0 +1,38 @@
+{
+ "gaudi2": {
+ "image2text_lora_finetune": {
+ "num_train_epochs": 2,
+ "eval_batch_size": 4,
+ "distribution": {
+ "multi_card": {
+ "learning_rate": 5e-5,
+ "train_batch_size": 2,
+ "train_runtime": 470,
+ "train_samples_per_second": 22,
+ "eval_accuracy": 0.6,
+ "extra_arguments": [
+ "--bf16",
+ "--gradient_accumulation_steps 8",
+ "--eval_strategy no",
+ "--save_strategy no",
+ "--warmup_steps 50",
+ "--lr_scheduler_type constant",
+ "--max_grad_norm 0.3",
+ "--logging_steps 1",
+ "--use_hpu_graphs_for_inference",
+ "--lora_rank 8",
+ "--lora_alpha 8",
+ "--lora_dropout 0.1",
+ "--lora_target_modules '.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$'",
+ "--low_cpu_mem_usage True",
+ "--adam_epsilon 1e-08",
+ "--input_column_name image query",
+ "--output_column_name answers",
+ "--remove_unused_columns False",
+ "--max_seq_length 512"
+ ]
+ }
+ }
+ }
+ }
+}
diff --git a/tests/baselines/idefics2_8b.json b/tests/baselines/idefics2_8b.json
new file mode 100644
index 0000000000..f40995c72d
--- /dev/null
+++ b/tests/baselines/idefics2_8b.json
@@ -0,0 +1,38 @@
+{
+ "gaudi2": {
+ "image2text_lora_finetune": {
+ "num_train_epochs": 2,
+ "eval_batch_size": 4,
+ "distribution": {
+ "multi_card": {
+ "learning_rate": 5e-5,
+ "train_batch_size": 2,
+ "train_runtime": 286,
+ "train_samples_per_second": 11.8,
+ "eval_accuracy": 0.6,
+ "extra_arguments": [
+ "--bf16",
+ "--gradient_accumulation_steps 8",
+ "--eval_strategy no",
+ "--save_strategy no",
+ "--warmup_steps 50",
+ "--lr_scheduler_type constant",
+ "--max_grad_norm 0.3",
+ "--logging_steps 1",
+ "--use_hpu_graphs_for_inference",
+ "--lora_rank 8",
+ "--lora_alpha 8",
+ "--lora_dropout 0.1",
+ "--lora_target_modules '.*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$'",
+ "--low_cpu_mem_usage True",
+ "--adam_epsilon 1e-08",
+ "--input_column_name image query",
+ "--output_column_name answers",
+ "--remove_unused_columns False",
+ "--max_seq_length 512"
+ ]
+ }
+ }
+ }
+ }
+}
diff --git a/tests/test_diffusers.py b/tests/test_diffusers.py
index 35556c57bf..b2526c7fa6 100755
--- a/tests/test_diffusers.py
+++ b/tests/test_diffusers.py
@@ -22,6 +22,7 @@
import os
import random
import re
+import shutil
import subprocess
import tempfile
import time
@@ -136,8 +137,10 @@
INPAINT_XL_THROUGHPUT_BASELINE_BF16 = 1.151
TEXT_TO_VIDEO_SYNTHESIS_BF16_BASELINE = 70
DETERMINISTIC_IMAGE_GENERATION_THROUGHPUT = 0.946
- THROUGHPUT_UNCONDITIONAL_IMAGE_BASELINE_BF16 = 7.671212047338486
+ THROUGHPUT_UNCONDITIONAL_IMAGE_BASELINE_BF16 = 0.15186785472532677
DEPTH2IMG_GENERATION_LATENCY_BASELINE_BF16 = 36.06376791000366
+ TEXTUAL_INVERSION_SDXL_THROUGHPUT = 2.6694
+ TEXTUAL_INVERSION_SDXL_RUNTIME = 74.92
else:
THROUGHPUT_BASELINE_BF16 = 0.309
THROUGHPUT_BASELINE_AUTOCAST = 0.114
@@ -148,9 +151,11 @@
INPAINT_THROUGHPUT_BASELINE_BF16 = 1.42
INPAINT_XL_THROUGHPUT_BASELINE_BF16 = 0.271
DETERMINISTIC_IMAGE_GENERATION_THROUGHPUT = 0.302
- THROUGHPUT_UNCONDITIONAL_IMAGE_BASELINE_BF16 = 3.095533166996529
+ THROUGHPUT_UNCONDITIONAL_IMAGE_BASELINE_BF16 = 0.050208662346013566
TEXT_TO_VIDEO_SYNTHESIS_BF16_BASELINE = 1000 # TODO: Get Gaudi 1 benchmark numbers
DEPTH2IMG_GENERATION_LATENCY_BASELINE_BF16 = 200 # TODO: Get Gaudi 1 Throughput
+ TEXTUAL_INVERSION_SDXL_THROUGHPUT = 2.695
+ TEXTUAL_INVERSION_SDXL_RUNTIME = 74.19
_run_custom_bf16_ops_test_ = parse_flag_from_env("CUSTOM_BF16_OPS", default=False)
@@ -831,6 +836,9 @@ def test_textual_inversion(self):
snapshot_download(
"diffusers/cat_toy_example", local_dir=data_dir, repo_type="dataset", ignore_patterns=".gitattributes"
)
+ cache_dir = Path(data_dir, ".cache")
+ if cache_dir.is_dir():
+ shutil.rmtree(cache_dir)
with tempfile.TemporaryDirectory() as run_dir:
cmd_line = [
"python3",
@@ -1196,6 +1204,77 @@ def test_stable_diffusion_xl_bf16(self):
self.assertEqual(image.shape, (64, 64, 3))
+ @slow
+ def test_textual_inversion_sdxl(self):
+ path_to_script = (
+ Path(os.path.dirname(__file__)).parent
+ / "examples"
+ / "stable-diffusion"
+ / "training"
+ / "textual_inversion_sdxl.py"
+ )
+ with tempfile.TemporaryDirectory() as data_dir:
+ snapshot_download(
+ "diffusers/cat_toy_example", local_dir=data_dir, repo_type="dataset", ignore_patterns=".gitattributes"
+ )
+ cache_dir = Path(data_dir, ".cache")
+ if cache_dir.is_dir():
+ shutil.rmtree(cache_dir)
+ with tempfile.TemporaryDirectory() as run_dir:
+ cmd_line = [
+ "python3",
+ f"{path_to_script}",
+ "--pretrained_model_name_or_path stabilityai/stable-diffusion-xl-base-1.0",
+ f"--train_data_dir {data_dir}",
+ "--learnable_property object",
+ "--placeholder_token ",
+ "--initializer_token toy",
+ "--resolution 64",
+ "--train_batch_size 1",
+ "--gradient_accumulation_steps 4",
+ "--max_train_steps 50",
+ "--learning_rate 5.0e-04",
+ "--scale_lr",
+ "--lr_scheduler constant",
+ "--lr_warmup_steps 0",
+ f"--output_dir {run_dir}",
+ "--save_as_full_pipeline",
+ "--gaudi_config_name Habana/stable-diffusion",
+ "--throughput_warmup_steps 3",
+ "--seed 27",
+ ]
+
+ pattern = re.compile(r"([\"\'].+?[\"\'])|\s")
+ cmd_line = [x for y in cmd_line for x in re.split(pattern, y) if x]
+ # Run textual inversion
+ p = subprocess.Popen(cmd_line)
+ return_code = p.wait()
+
+ # Ensure the run finished without any issue
+ self.assertEqual(return_code, 0)
+
+ # Assess throughput
+ with open(Path(run_dir) / "speed_metrics.json") as fp:
+ results = json.load(fp)
+ self.assertGreaterEqual(results["train_samples_per_second"], 0.95 * TEXTUAL_INVERSION_SDXL_THROUGHPUT)
+ self.assertLessEqual(results["train_runtime"], 1.05 * TEXTUAL_INVERSION_SDXL_RUNTIME)
+
+ pipe = GaudiStableDiffusionXLPipeline.from_pretrained(
+ run_dir,
+ torch_dtype=torch.bfloat16,
+ use_habana=True,
+ use_hpu_graphs=True,
+ gaudi_config=GaudiConfig(use_habana_mixed_precision=False),
+ )
+
+ set_seed(27)
+ prompt_1 = "A backpack"
+ prompt_2 = "A colored backpack"
+ image = pipe(
+ prompt=prompt_1, prompt_2=prompt_2, num_inference_steps=50, guidance_scale=7.5, output_type="np"
+ ).images[0]
+ self.assertEqual(image.shape, (1024, 1024, 3))
+
def test_stable_diffusion_xl_default(self):
components = self.get_dummy_components()
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 4fd61d9b7f..c5668e5b7c 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -34,6 +34,7 @@
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
+ MODEL_FOR_VISION_2_SEQ_MAPPING,
MODEL_MAPPING,
)
from transformers.testing_utils import slow
@@ -201,6 +202,11 @@ def is_valid_model_type(model_type: str) -> bool:
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
["t5"],
),
+ "run_image2text_lora_finetune": _get_supported_models_for_script(
+ MODELS_TO_TEST_MAPPING,
+ MODEL_FOR_VISION_2_SEQ_MAPPING,
+ ["idefics2", "mllama"],
+ ),
}
@@ -230,16 +236,21 @@ def to_test(
"meta-llama/LlamaGuard-7b",
]
+ case_only_in_gaudi2 = [
+ "sft",
+ "dpo",
+ "reward_modeling",
+ "ppo",
+ "prompt_tuning",
+ "peft_poly",
+ "run_sequence_classification",
+ "run_image2text_lora_finetune",
+ ]
+
if (fsdp or fp8) and not IS_GAUDI2:
return False
elif (
- "sft" in example_name
- or "dpo" in example_name
- or "reward_modeling" in example_name
- or "ppo" in example_name
- or "prompt_tuning" in example_name
- or "peft_poly" in example_name
- or example_name == "run_sequence_classification"
+ any(case in example_name for case in case_only_in_gaudi2)
or task_name in ("llama-adapter", "vera", "ia3", "adalora", "ln_tuning", "mamamiya405/finred")
) and not IS_GAUDI2:
return False
@@ -411,10 +422,9 @@ def test(self):
create_clip_roberta_model()
self._install_requirements(example_script.parent / "requirements.txt")
-
- path_to_baseline = BASELINE_DIRECTORY / Path(model_name.split("/")[-1].replace("-", "_")).with_suffix(
- ".json"
- )
+ path_to_baseline = BASELINE_DIRECTORY / Path(
+ model_name.split("/")[-1].replace("-", "_").replace(".", "_")
+ ).with_suffix(".json")
with path_to_baseline.open("r") as json_file:
device = "gaudi2" if IS_GAUDI2 else "gaudi"
baseline = json.load(json_file)[device]
@@ -439,7 +449,7 @@ def test(self):
env_variables = os.environ.copy()
if "falcon" in model_name:
- env_variables["LOWER_LIST"] = str(example_script.parent / "ops_bf16.txt")
+ env_variables["PT_HPU_AUTOCAST_LOWER_PRECISION_OPS_LIST"] = str(example_script.parent / "ops_bf16.txt")
elif "flan" in model_name:
env_variables["PT_HPU_MAX_COMPOUND_OP_SIZE"] = "512"
elif "bloom" in model_name:
@@ -450,13 +460,15 @@ def test(self):
env_variables["DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED"] = "1"
elif fsdp:
if "llama" in model_name:
- env_variables["LOWER_LIST"] = str(example_script.parent / "ops_bf16.txt")
+ env_variables["PT_HPU_AUTOCAST_LOWER_PRECISION_OPS_LIST"] = str(
+ example_script.parent / "ops_bf16.txt"
+ )
env_variables["PT_HPU_LAZY_MODE"] = "0"
elif deepspeed and "gpt-neox-20b" in model_name:
env_variables["LD_PRELOAD"] = ""
if fp8 and "llama" in model_name:
- env_variables["LOWER_LIST"] = str(example_script.parent / "ops_bf16.txt")
+ env_variables["PT_HPU_AUTOCAST_LOWER_PRECISION_OPS_LIST"] = str(example_script.parent / "ops_bf16.txt")
extra_command_line_arguments = baseline.get("distribution").get(distribution).get("extra_arguments", [])
@@ -922,6 +934,16 @@ class MultiCardCausalLanguageModelingLoRAFP8ExampleTester(
DATASET_NAME = "tatsu-lab/alpaca"
+class MultiCardImageToTextModelingLoRAExampleTester(
+ ExampleTesterBase,
+ metaclass=ExampleTestMeta,
+ example_name="run_image2text_lora_finetune",
+ multi_card=True,
+):
+ TASK_NAME = "image2text_lora_finetune"
+ DATASET_NAME = "nielsr/docvqa_1200_examples"
+
+
class MultiCardCausalLanguageModelingVeraExampleTester(
ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_lora_clm", multi_card=True
):
diff --git a/tests/test_functional_text_generation_example.py b/tests/test_functional_text_generation_example.py
new file mode 100644
index 0000000000..9012cabd65
--- /dev/null
+++ b/tests/test_functional_text_generation_example.py
@@ -0,0 +1,75 @@
+import json
+import os
+import re
+import subprocess
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import pytest
+
+from optimum.habana.utils import set_seed
+
+
+if os.environ.get("GAUDI2_CI", "0") == "1":
+ MODEL_OUTPUTS = {
+ "bigcode/starcoder": 'def print_hello_world():\n print("Hello World")\n\ndef print_hello_world_twice():\n print_hello_world()\n print_hello_world()\n\ndef print_hello_world_thrice():\n print_hello_world()\n print_hello_world()\n print_hello_world()\n\ndef print_hello_world_four_times():\n print_hello_world()\n print_hello_world()\n print_hello_world()\n ',
+ "bigcode/starcoder2-3b": 'def print_hello_world():\n print("Hello World")\n\ndef print_hello_world_with_name(name):\n print("Hello World, " + name)\n\ndef print_hello_world_with_name_and_age(name, age):\n print("Hello World, " + name + ", " + str(age))\n\ndef print_hello_world_with_name_and_age_and_gender(name, age, gender):\n print("Hello',
+ "google/gemma-7b": "DeepSpeed is a machine learning framework that enables training of large-scale models on commodity hardware. It is designed to be a drop-in replacement for PyTorch, and it is compatible with the existing PyTorch ecosystem. DeepSpeed is designed to be easy to use, and it provides a number of features that make it easy to train large-scale models.\n\nDeepSpeed is a machine learning framework that enables training of large-scale models on commodity hardware. It is designed to be a drop-in replacement for PyTorch, and",
+ "meta-llama/Llama-2-7b-hf": "DeepSpeed is a machine learning framework for deep learning. It is designed to be fast and efficient, while also being easy to use. DeepSpeed is based on the TensorFlow framework, and it uses the TensorFlow library to perform computations.\nDeepSpeed is a deep learning framework that is designed to be fast and efficient. It is based on the TensorFlow library and uses the TensorFlow library to perform computations. DeepSpeed is designed to be easy to use and to provide a high level of flex",
+ "mistralai/Mistral-7B-v0.1": "DeepSpeed is a machine learning framework that accelerates training of large models on a single machine or distributed systems. It is designed to be compatible with PyTorch and TensorFlow, and can be used to train models on a single machine or on a distributed system.\n\nDeepSpeed is a machine learning framework that accelerates training of large models on a single machine or distributed systems. It is designed to be compatible with PyTorch and TensorFlow, and can be used to train models on a single machine or on a distributed system",
+ "mistralai/Mixtral-8x7B-v0.1": "DeepSpeed is a machine learning framework that enables training of large models on a single machine with a single GPU. It is designed to be easy to use and efficient, and it can be used to train models on a variety of tasks.\n\n## Introduction\n\nDeepSpeed is a machine learning framework that enables training of large models on a single machine with a single GPU. It is designed to be easy to use and efficient, and it can be used to train models on a variety of tasks.\n\n## What is DeepSpeed",
+ "Qwen/Qwen2-7B": "DeepSpeed is a machine learning framework that provides a unified interface for training deep learning models. It is designed to be easy to use and to provide high performance on a variety of hardware platforms. DeepSpeed is built on top of PyTorch and TensorFlow, and it supports a wide range of models architectures, including transformer models, convolutional neural networks, and recurrent neural networks.\nDeepSpeed is designed to be easy to use, and it provides a unified interface for training deep learning models. It supports a wide range of model architectures, including",
+ }
+else:
+ # Functional testing only on G2 onwards
+ MODEL_OUTPUTS = []
+
+
+def _test_text_generation(
+ model_name: str,
+ token: str,
+):
+ set_seed(42)
+ command = ["python3"]
+ path_to_example_dir = Path(__file__).resolve().parent.parent / "examples"
+ env_variables = os.environ.copy()
+
+ command += [
+ f"{path_to_example_dir}/text-generation/run_generation.py ",
+ f"--model_name_or_path {model_name}",
+ "--use_kv_cache",
+ "--use_hpu_graphs",
+ "--bf16",
+ ]
+
+ with TemporaryDirectory() as tmp_dir:
+ command.append(f"--output_dir {tmp_dir}")
+ command.append(f"--token {token.value}")
+
+ pattern = re.compile(r"([\"\"].+?[\"\"])|\s")
+
+ command = [x for y in command for x in re.split(pattern, y) if x]
+ if "starcoder" in model_name:
+ command.append("--prompt")
+ command.append("def print_hello_world():")
+ print(f"\n\nCommand to test: {' '.join(command)}\n")
+ proc = subprocess.run(command, env=env_variables)
+
+ # Ensure the run finished without any issue
+ # Use try-except to avoid logging the token if used
+ try:
+ assert proc.returncode == 0
+ except AssertionError as e:
+ if "'--token', 'hf_" in e.args[0]:
+ e.args = (f"The following command failed:\n{' '.join(command[:-2])}",)
+ raise
+
+ with open(Path(tmp_dir) / "results.json") as fp:
+ results = json.load(fp)
+
+ assert results["output"][0][0] == MODEL_OUTPUTS[model_name]
+
+
+@pytest.mark.parametrize("model_name", MODEL_OUTPUTS.keys())
+def test_text_generation_bf16_1x(model_name: str, token: str):
+ _test_text_generation(model_name, token)
diff --git a/tests/test_image_to_text_example.py b/tests/test_image_to_text_example.py
index 81d6ec3d19..60049bf46e 100644
--- a/tests/test_image_to_text_example.py
+++ b/tests/test_image_to_text_example.py
@@ -19,6 +19,8 @@
("llava-hf/llava-v1.6-mistral-7b-hf", 1, 33.17984878151546),
("llava-hf/llava-v1.6-vicuna-7b-hf", 1, 35.00608681379742),
("llava-hf/llava-v1.6-vicuna-13b-hf", 1, 23.527610042925),
+ ("HuggingFaceM4/idefics2-8b", 1, 21.89944593215077),
+ ("meta-llama/Llama-3.2-11B-Vision-Instruct", 1, 20.407843538649303),
],
"fp8": [
("llava-hf/llava-1.5-7b-hf", 1, 98.72578382705062),
diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py
index 96d6043f36..f4a2ce059a 100644
--- a/tests/test_text_generation_example.py
+++ b/tests/test_text_generation_example.py
@@ -9,6 +9,8 @@
import pytest
+from optimum.habana.utils import set_seed
+
from .test_examples import TIME_PERF_FACTOR
@@ -19,32 +21,40 @@
# Gaudi2 CI baselines
MODELS_TO_TEST = {
"bf16_1x": [
- ("bigscience/bloomz-7b1", 1, False, 130.0472971205316),
- ("gpt2-xl", 1, False, 281.8734689674413),
- ("EleutherAI/gpt-j-6b", 1, False, 160.5823842101192),
- ("EleutherAI/gpt-neox-20b", 1, False, 50.67672679310354),
- ("meta-llama/Llama-2-7b-hf", 1, True, 141.25776956002076),
- ("tiiuae/falcon-40b", 1, True, 25.202450111088346),
- ("bigcode/starcoder", 256, True, 6846.575763562658),
- ("Salesforce/codegen2-1B", 1, False, 446.4029486883532),
- ("mosaicml/mpt-30b", 1, False, 36.06464336116623),
- ("mistralai/Mistral-7B-v0.1", 1, True, 130.2172236767782),
- ("mistralai/Mixtral-8x7B-v0.1", 1, False, 23.7931001677926),
- ("microsoft/phi-2", 1, False, 224.72307766211117),
- ("meta-llama/Meta-Llama-3-8B", 1, True, 129),
- ("meta-llama/Llama-2-7b-hf", 512, True, 12808),
- ("meta-llama/Llama-2-7b-hf", 512, False, 8711), # in some cases like TGI, reuse_cache isnt used
- ("stabilityai/stablelm-2-12b", 1, False, 74.8904496532218),
- ("codellama/CodeLlama-34b-hf", 1, True, 32.644),
- ("bigcode/starcoder2-3b", 1, False, 261.07213776344133),
- ("adept/persimmon-8b-base", 4, False, 366.73968820698406),
- ("Qwen/Qwen1.5-7B", 4, False, 490.8621617893209),
- ("google/gemma-7b", 1, False, 109.70751574382221),
- ("state-spaces/mamba-130m-hf", 1536, False, 5385.511100161605),
- ("Deci/DeciLM-7B", 1, False, 120),
- ("Qwen/Qwen2-7B", 512, False, 9669.45787),
- ("Qwen/Qwen1.5-MoE-A2.7B", 1, True, 44.25834541569395),
- ("EleutherAI/gpt-neo-2.7B", 1, False, 257.2476416844122),
+ ("bigscience/bloomz-7b1", 1, False, 130.0472971205316, False),
+ ("gpt2-xl", 1, False, 281.8734689674413, False),
+ ("EleutherAI/gpt-j-6b", 1, False, 160.5823842101192, False),
+ ("EleutherAI/gpt-neox-20b", 1, False, 50.67672679310354, False),
+ ("meta-llama/Llama-2-7b-hf", 1, True, 141.25776956002076, True),
+ ("tiiuae/falcon-40b", 1, True, 25.202450111088346, False),
+ (
+ "bigcode/starcoder",
+ 256,
+ True,
+ 6846.575763562658,
+ False,
+ ), # TODO: Enable check_output after model bigcode/starcoder is fixed
+ ("Salesforce/codegen2-1B", 1, False, 446.4029486883532, False),
+ ("mosaicml/mpt-30b", 1, False, 36.06464336116623, False),
+ ("mistralai/Mistral-7B-v0.1", 1, True, 130.2172236767782, True),
+ ("mistralai/Mixtral-8x7B-v0.1", 1, False, 23.7931001677926, True),
+ ("microsoft/phi-2", 1, False, 224.72307766211117, False),
+ ("meta-llama/Meta-Llama-3-8B", 1, True, 129, False),
+ ("meta-llama/Llama-2-7b-hf", 512, True, 12808, False),
+ ("meta-llama/Llama-2-7b-hf", 512, False, 8711, False), # in some cases like TGI, reuse_cache isnt used
+ ("stabilityai/stablelm-2-12b", 1, False, 74.8904496532218, False),
+ ("codellama/CodeLlama-34b-hf", 1, True, 32.644, False),
+ ("bigcode/starcoder2-3b", 1, False, 261.07213776344133, True),
+ ("adept/persimmon-8b-base", 4, False, 366.73968820698406, False),
+ ("Qwen/Qwen1.5-7B", 4, False, 490.8621617893209, False),
+ ("google/gemma-7b", 1, False, 109.70751574382221, True),
+ ("state-spaces/mamba-130m-hf", 1536, False, 5385.511100161605, False),
+ ("Deci/DeciLM-7B", 1, False, 120, False),
+ ("Qwen/Qwen2-7B", 512, False, 9669.45787, True),
+ ("Qwen/Qwen1.5-MoE-A2.7B", 1, True, 44.25834541569395, False),
+ ("EleutherAI/gpt-neo-2.7B", 1, False, 257.2476416844122, False),
+ ("facebook/xglm-1.7B", 1, False, 357.46365062825083, False),
+ ("CohereForAI/c4ai-command-r-v01", 1, False, 29.50315234651154, False),
],
"fp8": [
("tiiuae/falcon-180B", 4, 950, True, 128, 128, 2506.68),
@@ -89,41 +99,51 @@
("gpt2-xl", 1, False, 51.61471298016438),
],
}
+ MODEL_OUTPUTS = {
+ "bigcode/starcoder": 'def print_hello_world():\n print("Hello World")\n\ndef print_hello_world_twice():\n print_hello_world()\n print_hello_world()\n\ndef print_hello_world_thrice():\n print_hello_world()\n print_hello_world()\n print_hello_world()\n\ndef print_hello_world_four_times():\n print_hello_world()\n print_hello_world()\n print_hello_world()\n ',
+ "bigcode/starcoder2-3b": 'def print_hello_world():\n print("Hello World")\n\ndef print_hello_world_with_name(name):\n print("Hello World, " + name)\n\ndef print_hello_world_with_name_and_age(name, age):\n print("Hello World, " + name + ", " + str(age))\n\ndef print_hello_world_with_name_and_age_and_gender(name, age, gender):\n print("Hello',
+ "google/gemma-7b": "DeepSpeed is a machine learning framework that enables training of large-scale models on commodity hardware. It is designed to be a drop-in replacement for PyTorch, and it is compatible with the existing PyTorch ecosystem. DeepSpeed is designed to be easy to use, and it provides a number of features that make it easy to train large-scale models.\n\nDeepSpeed is a machine learning framework that enables training of large-scale models on commodity hardware. It is designed to be a drop-in replacement for PyTorch, and",
+ "meta-llama/Llama-2-7b-hf": "DeepSpeed is a machine learning framework for deep learning. It is designed to be fast and efficient, while also being easy to use. DeepSpeed is based on the TensorFlow framework, and it uses the TensorFlow library to perform computations.\nDeepSpeed is a deep learning framework that is designed to be fast and efficient. It is based on the TensorFlow library and uses the TensorFlow library to perform computations. DeepSpeed is designed to be easy to use and to provide a high level of flex",
+ "mistralai/Mistral-7B-v0.1": "DeepSpeed is a machine learning framework that accelerates training of large models on a single machine or distributed systems. It is designed to be compatible with PyTorch and TensorFlow, and can be used to train models on a single machine or on a distributed system.\n\nDeepSpeed is a machine learning framework that accelerates training of large models on a single machine or distributed systems. It is designed to be compatible with PyTorch and TensorFlow, and can be used to train models on a single machine or on a distributed system",
+ "mistralai/Mixtral-8x7B-v0.1": "DeepSpeed is a machine learning framework that enables training of large models on a single machine with a single GPU. It is designed to be easy to use and efficient, and it can be used to train models on a variety of tasks.\n\n## Introduction\n\nDeepSpeed is a machine learning framework that enables training of large models on a single machine with a single GPU. It is designed to be easy to use and efficient, and it can be used to train models on a variety of tasks.\n\n## What is DeepSpeed",
+ "Qwen/Qwen2-7B": "DeepSpeed is a machine learning framework that provides a suite of toolskits for building and training deep learning models. It is designed to be highly scalable and efficient, and it supports a wide range of deep learning frameworks, including PyTorch, TensorFlow, and MXNet. DeepSpeed is particularly well-suited for training large-scale models on distributed systems, and it provides a number of features that make it easy to use and configure. Some of the key features of DeepSpeed include:\n\n- Distributed training: DeepSpeed supports distributed training on multiple",
+ }
else:
# Gaudi1 CI baselines
MODELS_TO_TEST = {
"bf16_1x": [
- ("bigscience/bloomz-7b1", 1, False, 41.7555095197846),
- ("gpt2-xl", 1, False, 142.11481820425706),
+ ("bigscience/bloomz-7b1", 1, False, 41.7555095197846, False),
+ ("gpt2-xl", 1, False, 142.11481820425706, False),
# TODO: fix OPT 6.7B
# ("facebook/opt-6.7b", 0.0),
- ("EleutherAI/gpt-j-6b", 1, True, 156.2893125740893),
- ("meta-llama/Llama-2-7b-hf", 1, True, 44.39616259946937),
- ("tiiuae/falcon-7b", 1, True, 44.82870145718665),
- ("bigcode/starcoder", 1, False, 15.945023767901013),
- ("Salesforce/codegen2-1B", 1, False, 155.32071248826423),
- ("mosaicml/mpt-7b", 1, False, 45.45168927038262),
- ("mistralai/Mistral-7B-v0.1", 1, True, 41.21906841459711),
- ("microsoft/phi-2", 1, False, 92.53083167241344),
- ("google/gemma-7b", 1, False, 28.84284625836978),
- ("stabilityai/stablelm-2-12b", 1, False, 26.80858949645992),
- ("Qwen/Qwen1.5-7B", 1, False, 39.29068423087616),
- ("adept/persimmon-8b-base", 1, False, 34.53559807384106),
- ("bigcode/starcoder2-3b", 1, False, 82.09655684566117),
- ("state-spaces/mamba-130m-hf", 224, False, 794.542),
+ ("EleutherAI/gpt-j-6b", 1, True, 156.2893125740893, False),
+ ("meta-llama/Llama-2-7b-hf", 1, True, 44.39616259946937, False),
+ ("tiiuae/falcon-7b", 1, True, 44.82870145718665, False),
+ ("bigcode/starcoder", 1, False, 15.945023767901013, False),
+ ("Salesforce/codegen2-1B", 1, False, 155.32071248826423, False),
+ ("mosaicml/mpt-7b", 1, False, 45.45168927038262, False),
+ ("mistralai/Mistral-7B-v0.1", 1, True, 41.21906841459711, False),
+ ("microsoft/phi-2", 1, False, 92.53083167241344, False),
+ ("google/gemma-7b", 1, False, 28.84284625836978, False),
+ ("stabilityai/stablelm-2-12b", 1, False, 26.80858949645992, False),
+ ("Qwen/Qwen1.5-7B", 1, False, 39.29068423087616, False),
+ ("adept/persimmon-8b-base", 1, False, 34.53559807384106, False),
+ ("bigcode/starcoder2-3b", 1, False, 82.09655684566117, False),
+ ("state-spaces/mamba-130m-hf", 224, False, 794.542, False),
],
"fp8": [],
"load_quantized_model_with_autogptq": [],
"deepspeed": [
- ("bigscience/bloomz-7b1", 8, 1, 31.994268212011505),
+ ("bigscience/bloomz-7b1", 8, 1, 31.994268212011505, False),
],
"torch_compile": [],
"torch_compile_distributed": [],
"distributed_tp": [],
"contrastive_search": [
- ("gpt2-xl", 1, False, 34.48141280163397),
+ ("gpt2-xl", 1, False, 34.48141280163397, False),
],
}
+ MODEL_OUTPUTS = {}
def _test_text_generation(
@@ -141,6 +161,7 @@ def _test_text_generation(
max_output_tokens: int = 100,
parallel_strategy: str = None,
contrastive_search: bool = False,
+ check_output: bool = False,
):
command = ["python3"]
path_to_example_dir = Path(__file__).resolve().parent.parent / "examples"
@@ -291,7 +312,13 @@ def _test_text_generation(
)
command = [x for y in command for x in re.split(pattern, y) if x]
- print(f"\n\nCommand to test: {' '.join(command[:-2])}\n")
+ if "starcoder" in model_name and check_output:
+ command.append("--prompt")
+ command.append("def print_hello_world():")
+
+ set_seed(42)
+
+ print(f"\n\nCommand to test: {' '.join(command)}\n")
proc = subprocess.run(command, env=env_variables)
# Ensure the run finished without any issue
@@ -309,10 +336,24 @@ def _test_text_generation(
# Ensure performance requirements (throughput) are met
assert results["throughput"] >= (2 - TIME_PERF_FACTOR) * baseline
+ # Verify output for 1 HPU, BF16
+ if check_output and model_name in MODEL_OUTPUTS:
+ expected_output = MODEL_OUTPUTS[model_name]
+ assert results["output"][0][0] == expected_output
+
-@pytest.mark.parametrize("model_name, batch_size, reuse_cache, baseline", MODELS_TO_TEST["bf16_1x"])
-def test_text_generation_bf16_1x(model_name: str, baseline: float, batch_size: int, reuse_cache: bool, token: str):
- _test_text_generation(model_name, baseline, token, batch_size, reuse_cache)
+@pytest.mark.parametrize("model_name, batch_size, reuse_cache, baseline, check_output", MODELS_TO_TEST["bf16_1x"])
+def test_text_generation_bf16_1x(
+ model_name: str, baseline: float, batch_size: int, reuse_cache: bool, token: str, check_output: bool
+):
+ _test_text_generation(
+ model_name=model_name,
+ baseline=baseline,
+ token=token,
+ batch_size=batch_size,
+ reuse_cache=reuse_cache,
+ check_output=check_output,
+ )
@pytest.mark.parametrize(
diff --git a/tests/transformers/tests/models/mistral/test_modeling_mistral.py b/tests/transformers/tests/models/mistral/test_modeling_mistral.py
index d6a0bb3333..962eea1b0e 100644
--- a/tests/transformers/tests/models/mistral/test_modeling_mistral.py
+++ b/tests/transformers/tests/models/mistral/test_modeling_mistral.py
@@ -297,6 +297,46 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
test_headmasking = False
test_pruning = False
+ @unittest.skip(reason="This test is not supported for Mistral")
+ def test_beam_search_generate(self):
+ pass
+
+ @unittest.skip(reason="This test is not supported for Mistral")
+ def test_beam_search_generate_dict_output(self):
+ pass
+
+ @unittest.skip(reason="This test is not supported for Mistral")
+ def test_beam_search_generate_dict_outputs_use_cache(self):
+ pass
+
+ @unittest.skip(reason="This test is not supported for Mistral")
+ def test_attention_outputs(self):
+ pass
+
+ @unittest.skip(reason="This test is not supported for Mistral")
+ def test_constrained_beam_search_generate(self):
+ pass
+
+ @unittest.skip(reason="This test is not supported for Mistral")
+ def test_constrained_beam_search_generate_dict_output(self):
+ pass
+
+ @unittest.skip(reason="This test is not supported for Mistral")
+ def test_contrastive_generate_dict_outputs_use_cache(self):
+ pass
+
+ @unittest.skip(reason="This test is not supported for Mistral")
+ def test_greedy_generate_dict_outputs(self):
+ pass
+
+ @unittest.skip(reason="This test is not supported for Mistral")
+ def test_greedy_generate_dict_outputs_use_cache(self):
+ pass
+
+ @unittest.skip(reason="This test is not supported for Mistral")
+ def test_sample_generate_dict_output(self):
+ pass
+
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
diff --git a/tests/transformers/tests/models/mixtral/test_modeling_mixtral.py b/tests/transformers/tests/models/mixtral/test_modeling_mixtral.py
index 82c23b1f4f..1b2230aaf2 100644
--- a/tests/transformers/tests/models/mixtral/test_modeling_mixtral.py
+++ b/tests/transformers/tests/models/mixtral/test_modeling_mixtral.py
@@ -298,6 +298,46 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
test_headmasking = False
test_pruning = False
+ @unittest.skip(reason="This test is not supported for Mixtral")
+ def test_beam_search_generate(self):
+ pass
+
+ @unittest.skip(reason="This test is not supported for Mixtral")
+ def test_beam_search_generate_dict_output(self):
+ pass
+
+ @unittest.skip(reason="This test is not supported for Mixtral")
+ def test_beam_search_generate_dict_outputs_use_cache(self):
+ pass
+
+ @unittest.skip(reason="This test is not supported for Mixtral")
+ def test_attention_outputs(self):
+ pass
+
+ @unittest.skip(reason="This test is not supported for Mixtral")
+ def test_constrained_beam_search_generate_dict_output(self):
+ pass
+
+ @unittest.skip(reason="This test is not supported for Mixtral")
+ def test_contrastive_generate_dict_outputs_use_cache(self):
+ pass
+
+ @unittest.skip(reason="This test is not supported for Mixtral")
+ def test_greedy_generate_dict_outputs(self):
+ pass
+
+ @unittest.skip(reason="This test is not supported for Mixtral")
+ def test_greedy_generate_dict_outputs_use_cache(self):
+ pass
+
+ @unittest.skip(reason="This test is not supported for Mixtral")
+ def test_retain_grad_hidden_states_attentions(self):
+ pass
+
+ @unittest.skip(reason="This test is not supported for Mixtral")
+ def test_sample_generate_dict_output(self):
+ pass
+
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
diff --git a/tests/utils.py b/tests/utils.py
index daee779fa9..7eab1b06be 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -61,6 +61,8 @@
"code_llama": [("codellama/CodeLlama-13b-Instruct-hf", "Habana/llama")],
"protst": [("mila-intel/protst-esm1b-for-sequential-classification", "Habana/gpt2")],
"qwen2": [("Qwen/Qwen2-7B", "Habana/qwen"), ("Qwen/Qwen2-72B", "Habana/qwen")],
+ "idefics2": [("HuggingFaceM4/idefics2-8b", "Habana/gpt2")],
+ "mllama": [("meta-llama/Llama-3.2-11B-Vision-Instruct", "Habana/gpt2")],
}
MODELS_TO_TEST_FOR_QUESTION_ANSWERING = [