Add mixtral trl sft (#1349)
lkk12014402 authored Nov 28, 2024
1 parent 0a9aeba commit 22c6adb
Showing 3 changed files with 238 additions and 1 deletion.
91 changes: 90 additions & 1 deletion examples/trl/README.md
@@ -8,8 +8,10 @@ First, you should install the requirements:
$ pip install -U -r requirements.txt
```
## Supervised Finetuning

1. The following example performs supervised LoRA fine-tuning of the Qwen2 model on a conversational-format dataset.

```
python sft.py \
--model_name_or_path "Qwen/Qwen2-7B" \
--dataset_name "philschmid/dolly-15k-oai-style" \
@@ -38,11 +40,46 @@
--lora_target_modules "q_proj" "v_proj" "k_proj" "o_proj" \
--max_seq_length 512 \
--adam_epsilon 1e-08
```
2. Supervised fine-tuning of the mistralai/Mixtral-8x7B-v0.1 on 4 cards:
```
DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED=1 python ../gaudi_spawn.py --world_size 4 --use_deepspeed sft.py \
--model_name_or_path mistralai/Mixtral-8x7B-v0.1 \
--dataset_name "philschmid/dolly-15k-oai-style" \
--subset 'data/' \
--streaming False \
--deepspeed ../language-modeling/llama2_ds_zero3_config.json \
--output_dir="./model_mixtral" \
--do_train \
--max_steps=500 \
--logging_steps=10 \
--save_steps=100 \
--per_device_train_batch_size=2 \
--per_device_eval_batch_size=1 \
--gradient_accumulation_steps=2 \
--learning_rate=1e-4 \
--lr_scheduler_type="cosine" \
--warmup_steps=100 \
--weight_decay=0.05 \
--optim="paged_adamw_32bit" \
--lora_target_modules "q_proj" "v_proj" \
--bf16 \
--remove_unused_columns=False \
--max_seq_length 512 \
--run_name="sft_mixtral" \
--report_to=none \
--use_habana \
--use_lazy_mode
```
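
Both commands drive the same `sft.py` entry point. As a rough orientation, the sketch below shows how such a script can wire the parsed arguments into the Gaudi TRL classes. It is an illustration only, assuming `GaudiSFTConfig`/`GaudiSFTTrainer` are importable from `optimum.habana.trl` and `GaudiConfig` from `optimum.habana`; the dataset, LoRA, and hyperparameter values are placeholders, not the exact contents of `sft.py`.

```
# Illustrative sketch only -- not the exact contents of examples/trl/sft.py.
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.habana import GaudiConfig                          # assumed import path
from optimum.habana.trl import GaudiSFTConfig, GaudiSFTTrainer  # assumed import path

model_id = "Qwen/Qwen2-7B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
train_dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")

gaudi_config = GaudiConfig()  # Gaudi runtime settings (fused ops, etc.)

training_args = GaudiSFTConfig(
    output_dir="./model_qwen2",
    max_seq_length=512,
    per_device_train_batch_size=16,
    learning_rate=1e-4,
    use_habana=True,     # run on HPU
    use_lazy_mode=True,  # lazy-mode graph execution
)

peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    task_type="CAUSAL_LM",
)

trainer = GaudiSFTTrainer(
    model=model,
    gaudi_config=gaudi_config,  # extra argument added by optimum-habana
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
)
trainer.train()
```
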
## DPO pipeline
### Training
#### For meta-llama/Llama-2-7b-hf
The following example covers the creation of StackLlaMa 2: a Stack Exchange llama-v2-7b model.
There are two main steps to the DPO training process:
1. Supervised fine-tuning of the base llama-v2-7b model to create llama-v2-7b-se:
@@ -86,6 +123,58 @@
--output_dir="dpo" \
--report_to=none
```
#### For mistralai/Mistral-7B-v0.1
1. Supervised fine-tuning of the base Mistral-7B-v0.1 model:
```
DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED=1 python ../gaudi_spawn.py --world_size 8 --use_deepspeed sft.py \
--model_name_or_path mistralai/Mistral-7B-v0.1 \
--dataset_name "lvwerra/stack-exchange-paired" \
--deepspeed ../language-modeling/llama2_ds_zero3_config.json \
--output_dir="./sft" \
--do_train \
--max_steps=500 \
--logging_steps=10 \
--save_steps=100 \
--per_device_train_batch_size=1 \
--per_device_eval_batch_size=1 \
--gradient_accumulation_steps=2 \
--learning_rate=1e-4 \
--lr_scheduler_type="cosine" \
--warmup_steps=100 \
--weight_decay=0.05 \
--optim="paged_adamw_32bit" \
--lora_target_modules "q_proj" "v_proj" \
--bf16 \
--remove_unused_columns=False \
--run_name="sft_mistral" \
--report_to=none \
--use_habana \
--use_lazy_mode
```
To merge the adapters and get the final merged SFT checkpoint, we can use the `merge_peft_adapter.py` helper script that comes with TRL:
```
python merge_peft_adapter.py --base_model_name="mistralai/Mistral-7B-v0.1" --adapter_model_name="sft" --output_name="sft/final_merged_checkpoint"
```
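
Conceptually, the merge loads the base model, applies the saved LoRA adapter on top, and writes out a standalone checkpoint. A rough equivalent using the PEFT API directly would look like the sketch below (same paths as above; this is not the helper script itself):

```
# Sketch of what the adapter merge amounts to, using the PEFT API directly.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base, "sft")  # load the LoRA adapter saved by sft.py
merged = model.merge_and_unload()               # fold the adapter weights into the base weights

merged.save_pretrained("sft/final_merged_checkpoint")
AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1").save_pretrained("sft/final_merged_checkpoint")
```
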
2. Run the DPO trainer using the model saved by the previous step:
```
DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED=1 python ../gaudi_spawn.py --world_size 8 --use_deepspeed dpo.py \
--model_name_or_path="sft/final_merged_checkpoint" \
--tokenizer_name_or_path=mistralai/Mistral-7B-v0.1 \
--deepspeed ../language-modeling/llama2_ds_zero3_config.json \
--lora_target_modules "q_proj" "v_proj" "k_proj" "out_proj" "fc_in" "fc_out" "wte" \
--output_dir="dpo" \
--max_prompt_length=256 \
--max_length=512 \
--report_to=none
```
#### For meta-llama/Llama-2-70b-hf
For a large model like Llama2-70B, we can use DeepSpeed ZeRO-3 to enable DPO training across multiple cards. The steps are as follows:
1. Supervised fine-tuning of the base llama-v2-70b model to create llama-v2-70b-se:
1 change: 1 addition & 0 deletions optimum/habana/trl/trainer/sft_config.py
@@ -28,6 +28,7 @@ class GaudiSFTConfig(GaudiTrainingArguments):
    dataset_text_field: Optional[str] = None
    packing: Optional[bool] = True
    max_seq_length: Optional[int] = 1024
    pad_max: Optional[bool] = True
    dataset_num_proc: Optional[int] = None
    dataset_batch_size: int = 1000
    neftune_noise_alpha: Optional[float] = None
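
The new `pad_max` flag makes the non-packed data path pad every example up to `max_seq_length`, so the HPU sees one static input shape instead of recompiling for each batch length. Its effect on tokenization is the standard `padding="max_length"` behaviour, as in this small illustration (model name and sample text are placeholders; the pad-token fallback is an assumption, not part of the patch):

```
# Illustration of the static-shape padding that pad_max=True turns on.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B")
if tok.pad_token is None:
    tok.pad_token = tok.eos_token  # fallback in case the tokenizer defines no pad token

dynamic = tok("A short example.", truncation=True, max_length=512)
static = tok("A short example.", truncation=True, padding="max_length", max_length=512)

print(len(dynamic["input_ids"]))  # a handful of tokens -> shape varies per example
print(len(static["input_ids"]))   # always 512 -> a single static shape for the accelerator
```
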
147 changes: 147 additions & 0 deletions optimum/habana/trl/trainer/sft_trainer.py
@@ -17,6 +17,7 @@
from collections.abc import Mapping
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import datasets
import numpy as np
import torch
import torch.nn as nn
@@ -37,6 +38,7 @@
from trl.extras.dataset_formatting import get_formatting_func_from_dataset
from trl.import_utils import is_peft_available
from trl.trainer.utils import (
    ConstantLengthDataset,
    DataCollatorForCompletionOnlyLM,
    RichProgressCallback,
)
@@ -109,6 +111,7 @@ def __init__(
        packing: Optional[bool] = False,
        formatting_func: Optional[Callable] = None,
        max_seq_length: Optional[int] = None,
        pad_max: Optional[bool] = None,
        infinite: Optional[bool] = None,
        num_of_sequences: Optional[int] = 1024,
        chars_per_token: Optional[float] = 3.6,
@@ -126,6 +129,7 @@
        - add new args gaudi_config
        - use GaudiTrainer instead of Trainer
        - cast peft model to bf16.
        - add pad_max for static shape
        - num_buckets: Number of buckets. > 0 means apply bucketing, <= 0 means no bucketing
        """
        if num_buckets > 0:
@@ -273,6 +277,12 @@ def make_inputs_require_grad(module, input, output):
            )
            args.max_seq_length = max_seq_length

        if pad_max is not None:
            warnings.warn(
                "You passed a `pad_max` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`."
            )
            args.pad_max = pad_max

        if args.max_seq_length is None:
            # to overcome some issues with broken tokenizers
            max_seq_length = min(tokenizer.model_max_length, 1024)
@@ -371,6 +381,7 @@ def make_inputs_require_grad(module, input, output):
                    args.num_of_sequences,
                    args.chars_per_token,
                    remove_unused_columns=args.remove_unused_columns if args is not None else True,
                    pad_max=args.pad_max if args is not None else False,
                    **args.dataset_kwargs,
                )
            if eval_dataset is not None:
@@ -442,6 +453,142 @@ def make_inputs_require_grad(module, input, output):
                if callback.__class__.__name__ == "PrinterCallback":
                    self.callback_handler.pop_callback(callback)

    def _prepare_dataset(
        self,
        dataset,
        tokenizer,
        packing,
        dataset_text_field,
        max_seq_length,
        formatting_func,
        num_of_sequences,
        chars_per_token,
        remove_unused_columns=True,
        pad_max=False,
        append_concat_token=True,
        add_special_tokens=True,
        skip_prepare_dataset=False,
    ):
        """
        Copied from SFTTrainer._prepare_dataset https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/sft_trainer.py#L477
        The only differences are:
        - add pad_max for static shape
        """

        if dataset is None:
            raise ValueError("The dataset should not be None")

        if skip_prepare_dataset:
            return dataset

        # If the dataset is already preprocessed (tokenized), return as-is. Only works if dataset is
        # a datasets.Dataset or datasets.IterableDataset -- not for torch Dataset
        column_names = (
            dataset.column_names if isinstance(dataset, (datasets.Dataset, datasets.IterableDataset)) else None
        )
        if column_names and "input_ids" in column_names:
            if formatting_func is not None:
                warnings.warn(
                    "You passed a dataset that is already processed (contains an `input_ids` field) together with a valid formatting function. Therefore `formatting_func` will be ignored."
                )

            return dataset

        # check if torch dataset / dataloader and do nothing
        # see https://github.com/huggingface/trl/pull/1468 for why datasets.IterableDataset needs a separate check
        if isinstance(
            dataset, (torch.utils.data.IterableDataset, torch.utils.data.Dataset, ConstantLengthDataset)
        ) and not isinstance(dataset, datasets.IterableDataset):
            return dataset

        if not packing:
            return self._prepare_non_packed_dataloader(
                tokenizer,
                dataset,
                dataset_text_field,
                max_seq_length,
                formatting_func,
                add_special_tokens,
                remove_unused_columns,
                pad_max,
            )

        else:
            return self._prepare_packed_dataloader(
                tokenizer,
                dataset,
                dataset_text_field,
                max_seq_length,
                num_of_sequences,
                chars_per_token,
                formatting_func,
                append_concat_token,
                add_special_tokens,
            )

    def _prepare_non_packed_dataloader(
        self,
        tokenizer,
        dataset,
        dataset_text_field,
        max_seq_length,
        formatting_func=None,
        add_special_tokens=True,
        remove_unused_columns=True,
        pad_max=False,
    ):
        """
        Copied from SFTTrainer._prepare_non_packed_dataloader
        https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/sft_trainer.py#L542
        The only differences are:
        - add pad_max for static shape
        """

        use_formatting_func = formatting_func is not None and dataset_text_field is None
        self._dataset_sanity_checked = False

        # Inspired from: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt
        def tokenize(element):
            outputs = tokenizer(
                element[dataset_text_field] if not use_formatting_func else formatting_func(element),
                add_special_tokens=add_special_tokens,
                truncation=True,
                padding="max_length" if pad_max else False,
                max_length=max_seq_length,
                return_overflowing_tokens=False,
                return_length=False,
            )

            if use_formatting_func and not self._dataset_sanity_checked:
                if not isinstance(formatting_func(element), list):
                    raise ValueError(
                        "The `formatting_func` should return a list of processed strings since it can lead to silent bugs."
                    )
                else:
                    self._dataset_sanity_checked = True

            return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]}

        signature_columns = ["input_ids", "labels", "attention_mask"]

        extra_columns = list(set(dataset.column_names) - set(signature_columns))

        if not remove_unused_columns and len(extra_columns) > 0:
            warnings.warn(
                "You passed `remove_unused_columns=False` on a non-packed dataset. This might create some issues with the default collator and yield to errors. If you want to "
                f"inspect dataset other columns (in this case {extra_columns}), you can subclass `DataCollatorForLanguageModeling` in case you used the default collator and create your own data collator in order to inspect the unused dataset columns."
            )

        tokenized_dataset = dataset.map(
            tokenize,
            batched=True,
            remove_columns=dataset.column_names if remove_unused_columns else None,
            num_proc=self.dataset_num_proc,
            batch_size=self.dataset_batch_size,
        )

        return tokenized_dataset

    def _get_buckets(self, sentence_lengths, num_buckets):
        return np.unique(
            np.percentile(
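
The rest of the file is truncated in this view. For intuition only, the percentile-based bucketing that `_get_buckets` begins above can be sketched as follows; this is an illustration of the idea, not the exact implementation:

```
# Illustrative percentile bucketing: choose num_buckets length thresholds so each
# batch can be padded to the smallest bucket that fits it, keeping a small fixed
# set of shapes instead of a new shape per batch.
import numpy as np

def get_buckets(sentence_lengths, num_buckets):
    percentiles = np.linspace(0, 100, num_buckets + 1)[1:]          # e.g. ~[33, 67, 100]
    return np.unique(np.percentile(sentence_lengths, percentiles))  # deduplicated thresholds

lengths = [12, 18, 25, 40, 64, 70, 96, 300, 480, 505]
print(get_buckets(lengths, num_buckets=3))  # three length buckets covering the data
```
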
