Trl upgrade (#1245)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
sywangyi authored Sep 16, 2024
1 parent 520c875 commit 1a8ad12
Showing 11 changed files with 689 additions and 184 deletions.
11 changes: 8 additions & 3 deletions examples/trl/dpo.py
@@ -10,8 +10,8 @@
     is_deepspeed_available,
 )
 
-from optimum.habana import GaudiConfig, GaudiTrainingArguments
-from optimum.habana.trl import GaudiDPOTrainer
+from optimum.habana import GaudiConfig
+from optimum.habana.trl import GaudiDPOConfig, GaudiDPOTrainer
 from optimum.habana.utils import set_seed
 
 
@@ -48,6 +48,9 @@ class ScriptArguments:
     gradient_checkpointing: Optional[bool] = field(
         default=False, metadata={"help": "whether to use gradient checkpointing"}
     )
+    gradient_checkpointing_use_reentrant: Optional[bool] = field(
+        default=False, metadata={"help": "whether to use reentrant for gradient checkpointing"}
+    )
 
     lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
     lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
@@ -140,7 +143,7 @@ def return_prompt_and_responses(samples) -> Dict[str, str]:
     script_args = parser.parse_args_into_dataclasses()[0]
 
     # 1. initialize training arguments:
-    training_args = GaudiTrainingArguments(
+    training_args = GaudiDPOConfig(
         per_device_train_batch_size=script_args.per_device_train_batch_size,
         per_device_eval_batch_size=script_args.per_device_eval_batch_size,
         max_steps=script_args.max_steps,
@@ -159,6 +162,7 @@ def return_prompt_and_responses(samples) -> Dict[str, str]:
         bf16=True,
         remove_unused_columns=False,
         run_name="dpo_llama2",
+        gradient_checkpointing_kwargs={"use_reentrant": script_args.gradient_checkpointing_use_reentrant},
         use_habana=True,
         use_lazy_mode=True,
         use_hpu_graphs_for_training=not script_args.gradient_checkpointing and (not script_args.deepspeed),
@@ -246,6 +250,7 @@ def return_prompt_and_responses(samples) -> Dict[str, str]:
         peft_config=peft_config,
         max_prompt_length=script_args.max_prompt_length,
         max_length=script_args.max_length,
+        force_use_ref_model=True,
     )
 
     # 6. train
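Taken together, the dpo.py changes move the run configuration from GaudiTrainingArguments to the new GaudiDPOConfig, forward the gradient-checkpointing reentrant flag through gradient_checkpointing_kwargs, and pass force_use_ref_model=True to the trainer. Below is a minimal sketch of the resulting wiring; the model checkpoint, the toy dataset, and the exact GaudiDPOTrainer keyword list are illustrative assumptions following the example's pattern, not a copy of the script.

```python
# Sketch only: assumes a Gaudi machine with optimum-habana, trl == 0.9.6 and
# peft == 0.12.0 installed; model name and dataset are placeholders.
from datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.habana import GaudiConfig
from optimum.habana.trl import GaudiDPOConfig, GaudiDPOTrainer

training_args = GaudiDPOConfig(
    output_dir="./dpo_out",
    per_device_train_batch_size=4,
    max_steps=100,
    bf16=True,
    gradient_checkpointing=True,
    # New in this commit: the reentrant behaviour is forwarded explicitly.
    gradient_checkpointing_kwargs={"use_reentrant": False},
    use_habana=True,
    use_lazy_mode=True,
)

model_name = "meta-llama/Llama-2-7b-hf"  # assumed checkpoint
model = AutoModelForCausalLM.from_pretrained(model_name)
ref_model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Tiny preference dataset with the prompt/chosen/rejected columns DPO expects.
train_dataset = Dataset.from_dict(
    {
        "prompt": ["Question: What does DPO optimize?\n\nAnswer: "],
        "chosen": ["A preference objective over chosen vs. rejected answers."],
        "rejected": ["No idea."],
    }
)

peft_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05, task_type="CAUSAL_LM")
gaudi_config = GaudiConfig()

dpo_trainer = GaudiDPOTrainer(
    model,
    ref_model,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    gaudi_config=gaudi_config,
    peft_config=peft_config,
    max_prompt_length=256,
    max_length=512,
    force_use_ref_model=True,  # new keyword in this commit; needed here since both ref_model and peft_config are given
)
dpo_trainer.train()
```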
4 changes: 2 additions & 2 deletions examples/trl/requirements.txt
@@ -1,5 +1,5 @@
-trl == 0.8.6
-peft == 0.6.2
+trl == 0.9.6
+peft == 0.12.0
 datasets == 2.19.2
 tyro
 evaluate
12 changes: 4 additions & 8 deletions examples/trl/sft.py
@@ -15,8 +15,8 @@
     is_deepspeed_available,
 )
 
-from optimum.habana import GaudiConfig, GaudiTrainingArguments
-from optimum.habana.trl import GaudiSFTTrainer
+from optimum.habana import GaudiConfig
+from optimum.habana.trl import GaudiSFTConfig, GaudiSFTTrainer
 from optimum.habana.utils import set_seed
 
 
@@ -33,9 +33,7 @@ class ScriptArguments:
     size_valid_set: Optional[int] = field(default=4000, metadata={"help": "the size of the validation set"})
     streaming: Optional[bool] = field(default=True, metadata={"help": "whether to stream the dataset"})
     shuffle_buffer: Optional[int] = field(default=5000, metadata={"help": "the shuffle buffer size"})
-    max_seq_length: Optional[int] = field(default=1024, metadata={"help": "the max sequence length"})
     num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"})
-    packing: Optional[bool] = field(default=True, metadata={"help": "whether to use packing for SFTTrainer"})
     validation_split_percentage: Optional[int] = field(
         default=5,
         metadata={
@@ -73,7 +71,7 @@ class ScriptArguments:
 
 
 if __name__ == "__main__":
-    parser = HfArgumentParser((ScriptArguments, GaudiTrainingArguments))
+    parser = HfArgumentParser((ScriptArguments, GaudiSFTConfig))
     script_args, training_args = parser.parse_args_into_dataclasses()
     if script_args.use_peft:
         peft_config = LoraConfig(
@@ -87,7 +85,7 @@ class ScriptArguments:
     else:
         peft_config = None
 
-    if training_args.group_by_length and script_args.packing:
+    if training_args.group_by_length and training_args.packing:
         raise ValueError("Cannot use both packing and group by length")
 
     set_seed(training_args.seed)
@@ -187,8 +185,6 @@ def create_datasets(tokenizer, args, seed=None):
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         peft_config=peft_config,
-        packing=script_args.packing,
-        max_seq_length=script_args.max_seq_length,
         tokenizer=tokenizer,
         args=training_args,
         formatting_func=formatting_func,
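The sft.py diff mirrors this: packing and max_seq_length are no longer ScriptArguments forwarded to the trainer but fields of the new GaudiSFTConfig, which is also what HfArgumentParser now parses. A rough sketch of the resulting call follows; the model, dataset, formatting function, and exact keyword set are assumptions that track the example's pattern rather than guaranteed API.

```python
# Sketch only: assumes a Gaudi machine with optimum-habana and trl == 0.9.6;
# model name, dataset and formatting function are placeholders.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.habana import GaudiConfig
from optimum.habana.trl import GaudiSFTConfig, GaudiSFTTrainer

training_args = GaudiSFTConfig(
    output_dir="./sft_out",
    max_steps=100,
    per_device_train_batch_size=4,
    packing=True,         # moved from ScriptArguments onto the SFT config
    max_seq_length=1024,  # moved from ScriptArguments onto the SFT config
    use_habana=True,
    use_lazy_mode=True,
)

model = AutoModelForCausalLM.from_pretrained("gpt2")  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

train_dataset = Dataset.from_dict(
    {"question": ["What is SFT?"] * 32, "answer": ["Supervised fine-tuning."] * 32}
)


def formatting_func(example):
    # Placeholder prompt template, standing in for the script's own formatting.
    return f"Question: {example['question']}\n\nAnswer: {example['answer']}"


trainer = GaudiSFTTrainer(
    model=model,
    args=training_args,  # packing and max_seq_length are now read from here
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    gaudi_config=GaudiConfig(),
    formatting_func=formatting_func,
)
trainer.train()
```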
2 changes: 2 additions & 0 deletions optimum/habana/trl/__init__.py
@@ -1,8 +1,10 @@
 from .models.modeling_base import adapt_PreTrainedModelWrapper_to_gaudi
 from .models.modeling_sd_base import GaudiDefaultDDPOStableDiffusionPipeline
 from .trainer.ddpo_trainer import GaudiDDPOTrainer
+from .trainer.dpo_config import GaudiDPOConfig
 from .trainer.dpo_trainer import GaudiDPOTrainer
 from .trainer.ppo_config import GaudiPPOConfig
 from .trainer.ppo_trainer import GaudiPPOTrainer
 from .trainer.reward_trainer import GaudiRewardTrainer, RewardDataCollatorWithPadding
+from .trainer.sft_config import GaudiSFTConfig
 from .trainer.sft_trainer import GaudiSFTTrainer
2 changes: 2 additions & 0 deletions optimum/habana/trl/trainer/__init__.py
@@ -24,3 +24,5 @@
 from .reward_trainer import GaudiRewardTrainer, RewardDataCollatorWithPadding
 
 from .ddpo_trainer import GaudiDDPOTrainer
+from .dpo_config import GaudiDPOConfig
+from .sft_config import GaudiSFTConfig
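With both __init__.py updates above, the new configs are importable from the package root next to their trainers, for example:

```python
from optimum.habana.trl import GaudiDPOConfig, GaudiDPOTrainer, GaudiSFTConfig, GaudiSFTTrainer
```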
62 changes: 62 additions & 0 deletions optimum/habana/trl/trainer/dpo_config.py
@@ -0,0 +1,62 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, Literal, Optional

from trl.trainer.dpo_config import FDivergenceType

from ... import GaudiTrainingArguments


@dataclass
class GaudiDPOConfig(GaudiTrainingArguments):
r"""
Initialize GaudiDPOConfig.
Adapted from https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/dpo_config.py#L33
- inherit from GaudiTrainingArguments
"""

beta: float = 0.1
label_smoothing: float = 0
loss_type: Literal[
"sigmoid", "hinge", "ipo", "bco_pair", "sppo_hard", "nca_pair", "robust", "aot", "aot_pair", "exo_pair"
] = "sigmoid"
label_pad_token_id: int = -100
padding_value: Optional[int] = None
truncation_mode: str = "keep_end"
max_length: Optional[int] = None
max_prompt_length: Optional[int] = None
max_target_length: Optional[int] = None
is_encoder_decoder: Optional[bool] = None
disable_dropout: bool = True
generate_during_eval: bool = False
precompute_ref_log_probs: bool = False
dataset_num_proc: Optional[int] = None
model_init_kwargs: Optional[Dict] = None
ref_model_init_kwargs: Optional[Dict] = None
model_adapter_name: Optional[str] = None
ref_adapter_name: Optional[str] = None
reference_free: bool = False
force_use_ref_model: bool = False
f_divergence_type: Optional[FDivergenceType] = FDivergenceType.REVERSE_KL
f_alpha_divergence_coef: Optional[float] = 1.0
sync_ref_model: bool = False
ref_model_mixup_alpha: float = 0.9
ref_model_sync_steps: int = 64
rpo_alpha: Optional[float] = None

def __post_init__(self):
if self.loss_type == "kto_pair":
raise ValueError("Support for kto_pair has been removed in DPOTrainer. Please use KTOTrainer.")
return super().__post_init__()
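To illustrate the __post_init__ guard in the new config: any loss_type from the Literal list is accepted, while the removed kto_pair option fails fast. A small sketch; the output_dir is a placeholder, and the successful construction assumes a Gaudi environment since the parent GaudiTrainingArguments.__post_init__ runs afterwards.

```python
from optimum.habana.trl import GaudiDPOConfig

# A supported loss such as "ipo" passes straight through to the parent __post_init__.
cfg = GaudiDPOConfig(output_dir="/tmp/dpo", use_habana=True, use_lazy_mode=True, loss_type="ipo")

# The removed "kto_pair" loss is rejected before the parent __post_init__ runs.
try:
    GaudiDPOConfig(output_dir="/tmp/dpo", use_habana=True, use_lazy_mode=True, loss_type="kto_pair")
except ValueError as err:
    print(err)  # Support for kto_pair has been removed in DPOTrainer. Please use KTOTrainer.
```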