Skip to content

Commit

Permalink
Add multiprocessing in the DPO trainer. (#1286)
Browse files Browse the repository at this point in the history
* Update dpo_trainer.py

Added support for num_proc to tokenize the training dataset.

* Update dpo_trainer.py

added type in the new num_proc variable

* added test case

* add test case

* fix type

---------

Co-authored-by: imraviagrawal <ravi.agrawal@umass.edu>
Co-authored-by: Ravi Agrawal <raviagrawal@Ravis-MacBook-Pro.local>
  • Loading branch information
3 people authored Jan 30, 2024
1 parent ef441ea commit 737d771
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 1 deletion.
36 changes: 36 additions & 0 deletions tests/test_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,42 @@ def test_dpo_trainer_padding_token_is_none(self):

trainer.train()

def test_dpo_trainer_w_dataset_num_proc(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=1,
learning_rate=9e-1,
evaluation_strategy="steps",
)

dummy_dataset = self._init_dummy_dataset()

tokenizer = AutoTokenizer.from_pretrained(self.model_id)
tokenizer.pad_token = None

with self.assertRaisesRegex(
ValueError,
expected_regex=r"Padding is enabled, but the tokenizer is not configured with a padding token."
r" Explicitly set `tokenizer.pad_token` \(e.g. `tokenizer.pad_token = tokenizer.eos_token`\)"
r" before calling the trainer.",
):
trainer = DPOTrainer(
model=self.model,
ref_model=None,
beta=0.1,
args=training_args,
tokenizer=tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
dataset_num_proc=5,
)

trainer.train()

@require_no_wandb
def test_dpo_trainer_generate_during_eval_no_wandb(self):
with tempfile.TemporaryDirectory() as tmp_dir:
Expand Down
6 changes: 5 additions & 1 deletion trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ class DPOTrainer(Trainer):
precompute_ref_log_probs (`bool`, defaults to `False`):
Flag to precompute reference model log probabilities and evaluation datasets. This is useful if you want to train
without the reference model and reduce the total GPU memory needed.
dataset_num_proc (`Optional[int]`):
The number of workers to use to tokenize the data. Defaults to None.
model_init_kwargs: (`Optional[Dict]`, *optional*):
Dict of Optional kwargs to pass when instantiating the model from a string
ref_model_init_kwargs: (`Optional[Dict]`, *optional*):
Expand Down Expand Up @@ -162,6 +164,7 @@ def __init__(
generate_during_eval: bool = False,
compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
precompute_ref_log_probs: bool = False,
dataset_num_proc: int = None,
model_init_kwargs: Optional[Dict] = None,
ref_model_init_kwargs: Optional[Dict] = None,
model_adapter_name: str = None,
Expand Down Expand Up @@ -352,8 +355,9 @@ def make_inputs_require_grad(module, input, output):

self._stored_metrics = defaultdict(lambda: defaultdict(list))

self.dataset_num_proc = dataset_num_proc
# tokenize the dataset
train_dataset = train_dataset.map(self.tokenize_row)
train_dataset = train_dataset.map(self.tokenize_row, num_proc=self.dataset_num_proc)
if eval_dataset is not None:
eval_dataset = eval_dataset.map(self.tokenize_row)

Expand Down

0 comments on commit 737d771

Please sign in to comment.