Skip to content

Commit

Permalink
Refinements
Browse files Browse the repository at this point in the history
  • Loading branch information
regisss committed Jan 22, 2024
1 parent 65d7a68 commit fd7b0e1
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions examples/trl/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ class ScriptArguments:
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
},
)
seed: Optional[int] = field(default=0, metadata={"help": "the seed"})
seed: Optional[int] = field(
default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
)


def get_stack_exchange_paired(
Expand Down Expand Up @@ -131,6 +133,7 @@ def return_prompt_and_responses(samples) -> Dict[str, str]:
if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

# 1. initialize training arguments:
training_args = GaudiTrainingArguments(
per_device_train_batch_size=script_args.per_device_train_batch_size,
Expand All @@ -157,8 +160,10 @@ def return_prompt_and_responses(samples) -> Dict[str, str]:
use_hpu_graphs_for_inference=True,
seed=script_args.seed,
)
# initial seed for reproducible experiments

# Set seed before initializing model.
set_seed(training_args.seed)

# 2. load a pretrained model
model = AutoModelForCausalLM.from_pretrained(
script_args.model_name_or_path,
Expand Down

0 comments on commit fd7b0e1

Please sign in to comment.