Update sft_llama2.py to work with the latest API (#1637)
* Update sft_llama2.py to work with the latest API

SFTTrainer now takes an SFTConfig argument

* Update dpo_llama2.py

* pre-commit
xianbaoqian authored May 10, 2024
1 parent b8b8978 commit 5aeb752
Showing 2 changed files with 7 additions and 9 deletions.
6 changes: 3 additions & 3 deletions examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
@@ -7,9 +7,9 @@
 from accelerate import Accelerator
 from datasets import Dataset, load_dataset
 from peft import LoraConfig
-from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed
+from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, set_seed
 
-from trl import DPOTrainer
+from trl import DPOConfig, DPOTrainer
 
 
 # Define and parse arguments.
@@ -177,7 +177,7 @@ def return_prompt_and_responses(samples) -> Dict[str, str]:
     )
 
     # 4. initialize training arguments:
-    training_args = TrainingArguments(
+    training_args = DPOConfig(
         per_device_train_batch_size=script_args.per_device_train_batch_size,
         per_device_eval_batch_size=script_args.per_device_eval_batch_size,
         max_steps=script_args.max_steps,
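For context, a minimal sketch of the updated calling convention in dpo_llama2.py after the change above: the training arguments are now built as a trl.DPOConfig (a TrainingArguments subclass) instead of a plain transformers.TrainingArguments. The model name, dataset rows, and hyperparameter values below are illustrative placeholders, not values taken from the script, and the keyword names reflect the TRL releases this commit targets.

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder; the script reads this from ScriptArguments
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers ship without a pad token

# Tiny in-memory preference dataset with the prompt/chosen/rejected columns DPOTrainer expects.
train_dataset = Dataset.from_dict(
    {
        "prompt": ["Question: What is TRL?\n\nAnswer: "],
        "chosen": ["A library for training transformer language models with reinforcement learning."],
        "rejected": ["No idea."],
    }
)

# Previously a transformers.TrainingArguments; DPOConfig subclasses it and adds DPO-specific fields.
training_args = DPOConfig(
    output_dir="./dpo_output",
    per_device_train_batch_size=1,
    max_steps=10,
    learning_rate=5e-4,
)

trainer = DPOTrainer(
    model,
    ref_model=None,  # when omitted, the trainer handles the reference model internally
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()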
10 changes: 4 additions & 6 deletions examples/research_projects/stack_llama_2/scripts/sft_llama2.py
@@ -13,11 +13,10 @@
     AutoTokenizer,
     BitsAndBytesConfig,
     HfArgumentParser,
-    TrainingArguments,
     set_seed,
 )
 
-from trl import SFTTrainer
+from trl import SFTConfig, SFTTrainer
 from trl.import_utils import is_npu_available, is_xpu_available
 from trl.trainer import ConstantLengthDataset

@@ -33,7 +32,6 @@ class ScriptArguments:
     shuffle_buffer: Optional[int] = field(default=5000, metadata={"help": "the shuffle buffer size"})
     seq_length: Optional[int] = field(default=1024, metadata={"help": "the sequence length"})
     num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"})
-    packing: Optional[bool] = field(default=True, metadata={"help": "whether to use packing for SFTTrainer"})
     use_bnb: Optional[bool] = field(default=True, metadata={"help": "whether to use BitsAndBytes"})
 
     # LoraConfig
@@ -42,7 +40,7 @@ class ScriptArguments:
     lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
 
 
-parser = HfArgumentParser((ScriptArguments, TrainingArguments))
+parser = HfArgumentParser((ScriptArguments, SFTConfig))
 script_args, training_args = parser.parse_args_into_dataclasses()
 peft_config = LoraConfig(
     r=script_args.lora_r,
@@ -53,7 +51,7 @@ class ScriptArguments:
     task_type="CAUSAL_LM",
 )
 
-if training_args.group_by_length and script_args.packing:
+if training_args.group_by_length and training_args.packing:
     raise ValueError("Cannot use both packing and group by length")
 
 # `gradient_checkpointing` was True by default until `1f3314`, but it's actually not used.
@@ -172,8 +170,8 @@ def create_datasets(tokenizer, args, seed=None):
     train_dataset=train_dataset,
     eval_dataset=eval_dataset,
     peft_config=peft_config,
-    packing=script_args.packing,
+    max_seq_length=None,
     formatting_func=prepare_sample_text,
     tokenizer=tokenizer,
     args=training_args,
 )
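Similarly, a minimal sketch of the updated SFT calling convention shown in the diff above: packing is now a field on SFTConfig (filled from the command line by HfArgumentParser) rather than a ScriptArguments field passed to SFTTrainer directly. The model name, in-memory dataset, formatting function, and hyperparameters below are illustrative placeholders, not the script's actual values.

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder; the script reads this from ScriptArguments
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Small in-memory dataset standing in for the script's streamed StackExchange data.
train_dataset = Dataset.from_dict(
    {
        "question": ["What does SFT stand for?"] * 500,
        "response_j": ["Supervised fine-tuning."] * 500,
    }
)


def formatting_func(example):
    # Plays the role of the script's prepare_sample_text formatter.
    return f"Question: {example['question']}\n\nAnswer: {example['response_j']}"


# packing is configured on SFTConfig and read back as training_args.packing,
# matching the group_by_length check in the diff above.
training_args = SFTConfig(
    output_dir="./sft_output",
    per_device_train_batch_size=1,
    max_steps=10,
    packing=True,
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    formatting_func=formatting_func,
    tokenizer=tokenizer,  # keyword name used by the TRL releases this commit targets
)
trainer.train()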
