From 5aeb752053876cce64f2164a178635db08d96158 Mon Sep 17 00:00:00 2001
From: Tiezhen WANG <38108242+xianbaoqian@users.noreply.github.com>
Date: Fri, 10 May 2024 23:19:15 +0800
Subject: [PATCH] Update sft_llama2.py to work with the latest API (#1637)

* Update sft_llama2.py to work with the latest API

SFTTrainer now takes an SFTConfig argument

* Update dpo_llama2.py

* precommit

---
 .../stack_llama_2/scripts/dpo_llama2.py |  6 +++---
 .../stack_llama_2/scripts/sft_llama2.py | 10 ++++------
 2 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py b/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
index 08be4f0e9f..19c7939d37 100644
--- a/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
+++ b/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
@@ -7,9 +7,9 @@
 from accelerate import Accelerator
 from datasets import Dataset, load_dataset
 from peft import LoraConfig
-from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed
+from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, set_seed
 
-from trl import DPOTrainer
+from trl import DPOConfig, DPOTrainer
 
 
 # Define and parse arguments.
@@ -177,7 +177,7 @@ def return_prompt_and_responses(samples) -> Dict[str, str]:
     )
 
     # 4. initialize training arguments:
-    training_args = TrainingArguments(
+    training_args = DPOConfig(
         per_device_train_batch_size=script_args.per_device_train_batch_size,
         per_device_eval_batch_size=script_args.per_device_eval_batch_size,
         max_steps=script_args.max_steps,
diff --git a/examples/research_projects/stack_llama_2/scripts/sft_llama2.py b/examples/research_projects/stack_llama_2/scripts/sft_llama2.py
index 2719954fc2..8ab0a4c6c6 100644
--- a/examples/research_projects/stack_llama_2/scripts/sft_llama2.py
+++ b/examples/research_projects/stack_llama_2/scripts/sft_llama2.py
@@ -13,11 +13,10 @@
     AutoTokenizer,
     BitsAndBytesConfig,
     HfArgumentParser,
-    TrainingArguments,
     set_seed,
 )
 
-from trl import SFTTrainer
+from trl import SFTConfig, SFTTrainer
 from trl.import_utils import is_npu_available, is_xpu_available
 from trl.trainer import ConstantLengthDataset
 
@@ -33,7 +32,6 @@ class ScriptArguments:
     shuffle_buffer: Optional[int] = field(default=5000, metadata={"help": "the shuffle buffer size"})
     seq_length: Optional[int] = field(default=1024, metadata={"help": "the sequence length"})
     num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"})
-    packing: Optional[bool] = field(default=True, metadata={"help": "whether to use packing for SFTTrainer"})
     use_bnb: Optional[bool] = field(default=True, metadata={"help": "whether to use BitsAndBytes"})
 
     # LoraConfig
@@ -42,7 +40,7 @@ class ScriptArguments:
     lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
 
 
-parser = HfArgumentParser((ScriptArguments, TrainingArguments))
+parser = HfArgumentParser((ScriptArguments, SFTConfig))
 script_args, training_args = parser.parse_args_into_dataclasses()
 peft_config = LoraConfig(
     r=script_args.lora_r,
@@ -53,7 +51,7 @@ class ScriptArguments:
     task_type="CAUSAL_LM",
 )
 
-if training_args.group_by_length and script_args.packing:
+if training_args.group_by_length and training_args.packing:
     raise ValueError("Cannot use both packing and group by length")
 
 # `gradient_checkpointing` was True by default until `1f3314`, but it's actually not used.
@@ -172,8 +170,8 @@ def create_datasets(tokenizer, args, seed=None):
     train_dataset=train_dataset,
     eval_dataset=eval_dataset,
     peft_config=peft_config,
-    packing=script_args.packing,
     max_seq_length=None,
+    formatting_func=prepare_sample_text,
     tokenizer=tokenizer,
     args=training_args,
 )
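
For context, a minimal sketch of the calling convention this patch migrates to: options such as packing now live on SFTConfig (a TrainingArguments subclass) and reach SFTTrainer through its args parameter, with DPOConfig playing the same role for DPOTrainer. The model and dataset names below are placeholders for illustration, not part of this patch:

# Sketch of the post-change API: trainer-specific settings move onto SFTConfig
# instead of being passed to SFTTrainer as separate keyword arguments.
# "facebook/opt-350m" and "imdb" are placeholder examples, not from the patch.
from datasets import load_dataset

from trl import SFTConfig, SFTTrainer

train_dataset = load_dataset("imdb", split="train")  # placeholder dataset

training_args = SFTConfig(
    output_dir="./sft_output",
    max_steps=500,
    packing=True,               # formerly a ScriptArguments field forwarded to SFTTrainer
    dataset_text_field="text",  # needed here since no formatting_func is supplied
)

trainer = SFTTrainer(
    "facebook/opt-350m",        # a model name or an instantiated model both work
    train_dataset=train_dataset,
    args=training_args,         # SFTConfig replaces TrainingArguments here
)
trainer.train()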