Add mllama support (huggingface#1419)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
sywangyi authored and Liangyx2 committed Jan 20, 2025
1 parent 7b0d390 commit 61454c7
Showing 14 changed files with 1,396 additions and 38 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -235,6 +235,7 @@ The following model architectures, tasks and device distributions have been validated
| VideoMAE | | <div style="text-align:left"><li>Single card</li></div> | <li>[Video classification](https://github.com/huggingface/optimum-habana/tree/main/examples/video-classification)</li> |
| TableTransformer | | <div style="text-align:left"><li>Single card</li></div> | <li>[table object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/table-detection) </li> |
| DETR | | <div style="text-align:left"><li>Single card</li></div> | <li>[object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/object-detection)</li> |
| Mllama | <div style="text-align:left"><li>LoRA</li></div> | :heavy_check_mark: | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
</div>
1 change: 1 addition & 0 deletions docs/source/index.mdx
@@ -81,6 +81,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all been validated
| VideoMAE | | <div style="text-align:left"><li>Single card</li></div> | <li>[Video classification](https://github.com/huggingface/optimum-habana/tree/main/examples/video-classification)</li> |
| TableTransformer | | <div style="text-align:left"><li>Single card</li></div> | <li>[table object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/table-detection)</li> |
| DETR | | <div style="text-align:left"><li>Single card</li></div> | <li>[object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/object-detection)</li> |
| Mllama | <div style="text-align:left"><li>LoRA</li></div> | ✅ | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |

- Diffusers

79 changes: 79 additions & 0 deletions examples/image-to-text/README.md
@@ -31,6 +31,7 @@ Models that have been validated:
- [llava-hf/llava-v1.6-34b-hf](https://huggingface.co/llava-hf/llava-v1.6-34b-hf)
- [llava-hf/llama3-llava-next-8b-hf](https://huggingface.co/llava-hf/llama3-llava-next-8b-hf)
- [HuggingFaceM4/idefics2-8b](https://huggingface.co/HuggingFaceM4/idefics2-8b)
- [meta-llama/Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)

### Inference with BF16

@@ -102,6 +103,15 @@ python3 run_pipeline.py \
--bf16
```

To run Llama-3.2-11B-Vision-Instruct (mllama) inference, use the following command:

```bash
python3 run_pipeline.py \
--model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct \
--use_hpu_graphs \
--bf16
```

### Inference with FP8
Inference for Llava-1.5-7b, Llava-1.5-13b, Llava-v1.6-mistral-7b and Llava-v1.6-vicuna-13b in FP8 precision is enabled using [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html), which provides model measurement and quantization capabilities in PyTorch.

@@ -286,6 +296,75 @@ python3 ../gaudi_spawn.py \
--lora_target_modules '".*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"'
```

Here are single-card and multi-card command examples for fine-tuning meta-llama/Llama-3.2-11B-Vision-Instruct. The first command runs on a single HPU; the second launches the same fine-tuning on eight HPUs through `gaudi_spawn.py`. A note on the LoRA target regex follows the commands.

```bash
python3 run_image2text_lora_finetune.py \
--model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct \
--dataset_name nielsr/docvqa_1200_examples \
--bf16 True \
--output_dir ./model_lora_llama \
--num_train_epochs 2 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 8 \
--weight_decay 0.01 \
--logging_steps 25 \
--eval_strategy "no" \
--save_strategy "no" \
--learning_rate 5e-5 \
--warmup_steps 50 \
--lr_scheduler_type "constant" \
--input_column_names 'image' 'query' \
--output_column_names 'answers' \
--remove_unused_columns False \
--do_train \
--do_eval \
--use_habana \
--use_lazy_mode \
--lora_rank=8 \
--lora_alpha=8 \
--lora_dropout=0.1 \
--low_cpu_mem_usage True \
--max_seq_length=512 \
--use_hpu_graphs_for_inference True \
--lora_target_modules ".*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"
```

```bash
python3 ../gaudi_spawn.py \
--world_size 8 --use_mpi run_image2text_lora_finetune.py \
--model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct \
--dataset_name nielsr/docvqa_1200_examples \
--bf16 True \
--output_dir ./model_lora_llama \
--num_train_epochs 2 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 8 \
--weight_decay 0.01 \
--logging_steps 25 \
--eval_strategy "no" \
--save_strategy "no" \
--learning_rate 5e-5 \
--warmup_steps 50 \
--lr_scheduler_type "constant" \
--input_column_names 'image' 'query' \
--output_column_names 'answers' \
--remove_unused_columns False \
--do_train \
--do_eval \
--use_habana \
--use_lazy_mode \
--lora_rank=8 \
--lora_alpha=8 \
--lora_dropout=0.1 \
--low_cpu_mem_usage True \
--max_seq_length=512 \
--use_hpu_graphs_for_inference True \
--lora_target_modules '".*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"'
```

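In both commands, the `--lora_target_modules` regex restricts LoRA adapters to the projection layers of the text decoder (`language_model`), leaving the vision tower frozen. Below is a quick, illustrative sketch of which module names such a pattern selects, assuming PEFT's regex matching of module names; the module names listed are hypothetical examples, not an exhaustive list.

```python
import re

# LoRA target pattern from the commands above
pattern = r".*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"

# Hypothetical module names, for illustration only
names = [
    "language_model.model.layers.0.self_attn.q_proj",      # selected: text-decoder attention projection
    "language_model.model.layers.0.mlp.gate_proj",         # selected: text-decoder MLP projection
    "vision_model.transformer.layers.0.self_attn.q_proj",  # not selected: vision tower stays frozen
]
for name in names:
    print(name, "->", bool(re.fullmatch(pattern, name)))
```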
## Multi-HPU inference

To enable multi-card inference, you must set the environment variable `PT_HPU_ENABLE_LAZY_COLLECTIVES=true`,
17 changes: 11 additions & 6 deletions examples/image-to-text/run_image2text_lora_finetune.py
@@ -251,11 +251,9 @@ class FinetuneArguments:


class MyDataCollator:
def __init__(self, processor, max_seq_length):
def __init__(self, processor, max_seq_length, image_token_id):
self.processor = processor
self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
processor.tokenizer.additional_special_tokens.index("<image>")
]
self.image_token_id = image_token_id
self.max_seq_length = max_seq_length

def __call__(self, examples):
@@ -458,8 +456,15 @@ def main():
if col not in (data_args.input_column_names + data_args.output_column_names)
]
)

data_collator = MyDataCollator(processor, max_seq_length=data_args.max_seq_length)
if hasattr(config, "image_token_id"):
# idefics
image_token_id = config.image_token_id
elif hasattr(config, "image_token_index"):
# mllama
image_token_id = config.image_token_index
else:
raise ValueError("Please provide value for image_token_id")
data_collator = MyDataCollator(processor, max_seq_length=data_args.max_seq_length, image_token_id=image_token_id)

gaudi_config = GaudiConfig()
gaudi_config.use_fused_adam = True
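The collator change above means the image placeholder token id now comes from the model config (`image_token_id` for idefics2, `image_token_index` for mllama) instead of being looked up in the tokenizer's additional special tokens. As a rough illustration of why a collator needs this id, the sketch below masks padding and image placeholder positions out of the language-modeling loss; this is an assumption for illustration, not the exact body of `MyDataCollator`.

```python
import torch

def build_labels(input_ids: torch.Tensor, pad_token_id: int, image_token_id: int) -> torch.Tensor:
    """Illustrative only: exclude padding and image placeholder tokens from the LM loss."""
    labels = input_ids.clone()
    labels[labels == pad_token_id] = -100    # ignored by the cross-entropy loss
    labels[labels == image_token_id] = -100  # image placeholders carry no text target
    return labels

# Toy example: id 5 stands in for the image placeholder, id 0 for padding
ids = torch.tensor([[1, 5, 5, 7, 8, 0, 0]])
print(build_labels(ids, pad_token_id=0, image_token_id=5))
# tensor([[   1, -100, -100,    7,    8, -100, -100]])
```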
41 changes: 27 additions & 14 deletions examples/image-to-text/run_pipeline.py
@@ -23,7 +23,7 @@
import PIL.Image
import requests
import torch
from transformers import AutoConfig, AutoProcessor, pipeline
from transformers import AutoConfig, AutoModelForVision2Seq, AutoProcessor, pipeline

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

@@ -185,14 +185,13 @@ def main():
adapt_transformers_to_gaudi()

model_type = AutoConfig.from_pretrained(args.model_name_or_path).model_type
if args.image_path is None and model_type in ["llava", "idefics2"]:
if args.image_path is None and model_type in ["llava", "idefics2", "mllama"]:
args.image_path = ["https://llava-vl.github.io/static/images/view.jpg"]
elif args.image_path is None and model_type == "llava_next":
args.image_path = [
"https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
]

if args.prompt is None and model_type in ["llava", "idefics2", "llava_next"]:
if args.prompt is None and model_type in ["llava", "idefics2", "llava_next", "mllama"]:
processor = AutoProcessor.from_pretrained(args.model_name_or_path)
conversation = [
{
@@ -231,17 +230,31 @@

htcore.hpu_set_env()

generator = pipeline(
"image-to-text",
model=args.model_name_or_path,
torch_dtype=model_dtype,
device="hpu",
)

if args.world_size > 1:
generator.model = initialize_distributed_model(args, generator.model, logger, model_dtype)

import deepspeed

with deepspeed.OnDevice(dtype=model_dtype, device="cpu"):
model = AutoModelForVision2Seq.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype)
if model_type == "mllama":
model.language_model = initialize_distributed_model(args, model.language_model, logger, model_dtype)
else:
model = initialize_distributed_model(args, model, logger, model_dtype)
generator = pipeline(
"image-to-text",
model=model,
config=args.model_name_or_path,
tokenizer=args.model_name_or_path,
image_processor=args.model_name_or_path,
torch_dtype=model_dtype,
device="hpu",
)
else:
generator = pipeline(
"image-to-text",
model=args.model_name_or_path,
torch_dtype=model_dtype,
device="hpu",
)
if args.use_hpu_graphs:
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

@@ -263,7 +276,7 @@ def main():
htcore.hpu_initialize(generator.model)

# delete once the pipeline integrates AutoProcessor as the preprocessing engine
if model_type in ["idefics2"]:
if model_type in ["idefics2", "mllama"]:
from transformers.image_utils import load_image

def preprocess(self, image, prompt=None, timeout=None):
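For multi-card runs, the pipeline is now built from a model loaded explicitly with `AutoModelForVision2Seq`; for mllama only the language model is handed to DeepSpeed for distribution while the vision tower is left as-is. The `preprocess` override mentioned above feeds images and prompts through the checkpoint's processor until the pipeline does this natively. A minimal, hedged sketch of that preprocessing path (the conversation format mirrors the one used earlier in this script; the checkpoint is gated and the exact keyword arguments may differ):

```python
from transformers import AutoProcessor
from transformers.image_utils import load_image

processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")

conversation = [
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is shown in this image?"}]},
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

image = load_image("https://llava-vl.github.io/static/images/view.jpg")  # URL, local path, or PIL image
inputs = processor(images=image, text=prompt, return_tensors="pt")
```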
48 changes: 36 additions & 12 deletions optimum/habana/transformers/generation/utils.py
@@ -109,6 +109,7 @@
"xglm",
"whisper",
"idefics2",
"mllama",
]


@@ -330,11 +331,13 @@ def _expand_dict_for_generation(dict_to_expand):

def _pad_past_key_values(self, model_kwargs):
pad_amount = model_kwargs.get("kv_cache_pad_len", 0)
kv_cache_len = model_kwargs.get("kv_cache_len", 0)
if model_kwargs["past_key_values"]:
if model_kwargs.get("mqa_model", False):
for i in range(len(model_kwargs["past_key_values"])): # layer
if torch.is_tensor(
model_kwargs["past_key_values"][i]
if (
torch.is_tensor(model_kwargs["past_key_values"][i])
and model_kwargs["past_key_values"][i].shape[-2] == kv_cache_len - pad_amount
): # tensor(batch_size, kv_cache_len, n_heads * head_dim * 2) k and v stacked
model_kwargs["past_key_values"][i] = torch.nn.functional.pad(
model_kwargs["past_key_values"][i], (0, 0, 0, pad_amount)
@@ -344,8 +347,9 @@ def _pad_past_key_values(self, model_kwargs):
else:
for i in range(len(model_kwargs["past_key_values"])): # layer
for j in range(len(model_kwargs["past_key_values"][i])): # k or v
if torch.is_tensor(
model_kwargs["past_key_values"][i][j]
if (
torch.is_tensor(model_kwargs["past_key_values"][i][j])
and model_kwargs["past_key_values"][i][j].shape[-2] == kv_cache_len - pad_amount
): # tensor(batch_size, n_heads, kv_cache_len, head_dim)
model_kwargs["past_key_values"][i][j] = torch.nn.functional.pad(
model_kwargs["past_key_values"][i][j], (0, 0, 0, pad_amount)
@@ -461,6 +465,14 @@ def update_model_kwargs_for_bucketing(
)
else:
assert False, "Not tested for cases where attn_mask isnt passed"

if model_kwargs.get("cross_attention_mask") is not None:
model_kwargs["cross_attention_mask"] = torch.nn.functional.pad(
model_kwargs["cross_attention_mask"],
(0, 0, 0, 0, 0, pad_amount),
value=0,
)

if reduce_recompile and params["passnum"] == 0:
position_ids_cpu = model_kwargs["attention_mask"].long().cumsum(-1) - 1
position_ids_cpu.masked_fill_(model_kwargs["attention_mask"] == 0, 1)
@@ -503,14 +515,20 @@ def create_pad_arg(pad_amount, i, j):
# This is a necessary (but not sufficient) condition: whatever dimension we are padding should be a multiple of bucket_size
# This check is added in case we get a new model with a new kv-cache structure, and we attempt to pad some wrong dimension
# In the PEFT case with virtual tokens, model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)] % bucket_size == num_virtual_tokens, so the assert compares against num_virtual_tokens rather than zero; the pad length of past_key_values stays aligned with input_ids and attention_mask
num_virtual_tokens = model_kwargs.get("num_virtual_tokens", 0)
assert (
model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)] % bucket_size
== num_virtual_tokens
)
tmp_lst[j] = torch.nn.functional.pad(
model_kwargs["past_key_values"][i][j], pad_tuple, value=pad_token_id
)
if (
model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)]
== params["allocated_space"] - pad_amount
):
num_virtual_tokens = model_kwargs.get("num_virtual_tokens", 0)
assert (
model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)] % bucket_size
== num_virtual_tokens
)
tmp_lst[j] = torch.nn.functional.pad(
model_kwargs["past_key_values"][i][j], pad_tuple, value=pad_token_id
)
else:
tmp_lst[j] = model_kwargs["past_key_values"][i][j]
new_kv[i] = tuple(tmp_lst)
model_kwargs["past_key_values"] = tuple(new_kv)

@@ -1110,6 +1128,12 @@ def generate(
(0, generation_config.max_new_tokens),
value=0,
)
if model_kwargs.get("cross_attention_mask") is not None:
model_kwargs["cross_attention_mask"] = torch.nn.functional.pad(
model_kwargs["cross_attention_mask"],
(0, 0, 0, 0, 0, generation_config.max_new_tokens),
value=0,
)
else:
assert generation_config.bucket_size <= 0, "Untested path for bucket>0"
if model_kwargs.get("decoder_input_ids", None) is None:
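The new `cross_attention_mask` handling relies on `torch.nn.functional.pad` reading the pad tuple from the last dimension backwards, so `(0, 0, 0, 0, 0, pad_amount)` extends only the third-from-last dimension. Assuming mllama's cross-attention mask is laid out as `(batch, text_seq_len, num_images, num_tiles)`, this zero-pads the text sequence axis so it stays aligned with the padded `input_ids` and `attention_mask`. A minimal sketch of the padding semantics (shapes are hypothetical):

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: batch=1, text_seq_len=10, num_images=1, num_tiles=4
cross_attention_mask = torch.ones(1, 10, 1, 4)

# The pad tuple is consumed from the last dim backwards: (0,0) for dim -1, (0,0) for dim -2, (0,16) for dim -3
padded = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, 16), value=0)
print(padded.shape)  # torch.Size([1, 26, 1, 4]); the 16 new positions get a zeroed mask
```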
18 changes: 18 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
@@ -88,6 +88,14 @@
GaudiMixtralDecoderLayer,
GaudiMixtralForCausalLM,
GaudiMixtralModel,
GaudiMllamaCrossAttentionDecoderLayer,
GaudiMllamaForCausalLM,
GaudiMllamaForConditionalGeneration,
GaudiMllamaSelfAttentionDecoderLayer,
GaudiMllamaTextCrossAttention,
GaudiMllamaTextModel,
GaudiMllamaTextSelfAttention,
GaudiMllamaVisionModel,
GaudiMptAttention,
GaudiMptBlock,
GaudiMptForCausalLM,
@@ -622,6 +630,16 @@ def adapt_transformers_to_gaudi():
transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration = GaudiWhisperForConditionalGeneration
transformers.models.whisper.modeling_whisper.WHISPER_ATTENTION_CLASSES = GAUDI_WHISPER_ATTENTION_CLASSES

# Optimization for mllama on Gaudi
transformers.models.mllama.modeling_mllama.MllamaSelfAttentionDecoderLayer = GaudiMllamaSelfAttentionDecoderLayer
transformers.models.mllama.modeling_mllama.MllamaCrossAttentionDecoderLayer = GaudiMllamaCrossAttentionDecoderLayer
transformers.models.mllama.modeling_mllama.MllamaForCausalLM = GaudiMllamaForCausalLM
transformers.models.mllama.modeling_mllama.MllamaTextSelfAttention = GaudiMllamaTextSelfAttention
transformers.models.mllama.modeling_mllama.MllamaTextCrossAttention = GaudiMllamaTextCrossAttention
transformers.models.mllama.modeling_mllama.MllamaForConditionalGeneration = GaudiMllamaForConditionalGeneration
transformers.models.mllama.modeling_mllama.MllamaTextModel = GaudiMllamaTextModel
transformers.models.mllama.modeling_mllama.MllamaVisionModel = GaudiMllamaVisionModel

transformers.AutoConfig.register("deci", DeciLMConfig)
transformers.AutoModelForCausalLM.register(DeciLMConfig, DeciLMForCausalLM)

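These assignments take effect once `adapt_transformers_to_gaudi()` is called, after which models loaded through the usual `transformers` entry points pick up the Gaudi-optimized mllama classes. A minimal usage sketch, assuming a Gaudi machine with `habana_frameworks` installed and access to the gated checkpoint:

```python
import torch
import habana_frameworks.torch.core as htcore  # noqa: F401  -- registers the "hpu" device
from transformers import AutoModelForVision2Seq

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

# Patch transformers with the Gaudi-optimized classes before instantiating the model
adapt_transformers_to_gaudi()

model = AutoModelForVision2Seq.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct", torch_dtype=torch.bfloat16
)
model = model.eval().to("hpu")
```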
10 changes: 10 additions & 0 deletions optimum/habana/transformers/models/__init__.py
@@ -150,6 +150,16 @@
gaudi_mixtral_block_sparse_moe_forward,
gaudi_mixtral_rmsnorm_forward,
)
from .mllama import (
GaudiMllamaCrossAttentionDecoderLayer,
GaudiMllamaForCausalLM,
GaudiMllamaForConditionalGeneration,
GaudiMllamaSelfAttentionDecoderLayer,
GaudiMllamaTextCrossAttention,
GaudiMllamaTextModel,
GaudiMllamaTextSelfAttention,
GaudiMllamaVisionModel,
)
from .modeling_all_models import (
gaudi_check_and_enable_sdpa,
gaudi_conv1d_forward,
10 changes: 10 additions & 0 deletions optimum/habana/transformers/models/mllama/__init__.py
@@ -0,0 +1,10 @@
from .modeling_mllama import (
GaudiMllamaCrossAttentionDecoderLayer,
GaudiMllamaForCausalLM,
GaudiMllamaForConditionalGeneration,
GaudiMllamaSelfAttentionDecoderLayer,
GaudiMllamaTextCrossAttention,
GaudiMllamaTextModel,
GaudiMllamaTextSelfAttention,
GaudiMllamaVisionModel,
)