Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DPO trainer supports num_logits_to_keep to save memory #2129

Merged
merged 55 commits into from
Nov 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
72d4647
Support num_logits_to_keep, which computes necessary logits in the fo…
xyangk Sep 20, 2024
f1ea892
Merge branch 'huggingface:main' into main
xyangk Sep 20, 2024
13cb564
update doc
xyangk Sep 20, 2024
89cc2fc
Merge branch 'main' of github.com:xyangk/trl
xyangk Sep 20, 2024
efb12cb
bug fix
xyangk Sep 26, 2024
2348948
Merge branch 'huggingface:main' into main
xyangk Sep 26, 2024
88ccdd0
Merge branch 'main' into keep_necessary_logits
xyangk Sep 26, 2024
1b21865
update
xyangk Sep 26, 2024
03b1b67
check is model supports num_logits_to_keep
xyangk Sep 26, 2024
96a37ae
Merge branch 'main' into keep_necessary_logits
xyangk Sep 26, 2024
bd0490d
Merge branch 'main' into keep_necessary_logits
xyangk Sep 27, 2024
1ede5e7
Merge branch 'huggingface:main' into main
xyangk Sep 27, 2024
802f366
Merge branch 'main' into keep_necessary_logits
kashif Sep 27, 2024
6ca9cbe
ruff format
xyangk Sep 27, 2024
2b240a0
Merge branch 'keep_necessary_logits' of github.com:xyangk/trl into ke…
xyangk Sep 27, 2024
3adf9ab
Merge branch 'huggingface:main' into main
xyangk Sep 27, 2024
5bc196d
precommit
xyangk Sep 27, 2024
2877cff
Merge branch 'main' into keep_necessary_logits
xyangk Sep 27, 2024
1e55540
Merge branch 'huggingface:main' into main
xyangk Sep 27, 2024
15a73e2
Merge branch 'main' into keep_necessary_logits
kashif Sep 27, 2024
f2ac776
update test file
xyangk Sep 29, 2024
4790868
Merge branch 'keep_necessary_logits' of github.com:xyangk/trl into ke…
xyangk Sep 29, 2024
cbe58bb
peft model support
xyangk Sep 29, 2024
725ccf0
test passed
xyangk Sep 29, 2024
e5e0605
Merge branch 'main' into keep_necessary_logits
kashif Oct 6, 2024
4815b75
Merge branch 'huggingface:main' into main
xyangk Oct 7, 2024
273f519
Merge branch 'huggingface:main' into main
xyangk Oct 8, 2024
aaf889a
Merge branch 'main' into keep_necessary_logits
xyangk Oct 8, 2024
fa043e5
Merge branch 'keep_necessary_logits' of github.com:xyangk/trl into ke…
xyangk Oct 8, 2024
aff9955
Merge branch 'main' into keep_necessary_logits
kashif Oct 8, 2024
d385930
Merge branch 'main' into keep_necessary_logits
kashif Oct 11, 2024
ed78dc4
Merge branch 'huggingface:main' into main
xyangk Oct 11, 2024
86b81db
Merge branch 'main' into keep_necessary_logits
xyangk Oct 11, 2024
d1ada50
update
xyangk Oct 11, 2024
6512445
Merge branch 'keep_necessary_logits' of github.com:xyangk/trl into ke…
xyangk Oct 11, 2024
56c5c22
Merge branch 'main' into keep_necessary_logits
xyangk Oct 16, 2024
a75e4a2
Merge branch 'main' of github.com:xyangk/trl
xyangk Oct 23, 2024
c583b57
Merge branch 'main' into keep_necessary_logits
xyangk Oct 23, 2024
56241e0
apply use_num_logits_to_keep
xyangk Oct 23, 2024
dfb2fa9
fix num_logits_to_keep compute bug
xyangk Oct 25, 2024
9e9e2ef
compare all outputs
xyangk Oct 25, 2024
858d404
pytest
xyangk Oct 25, 2024
4e37726
Merge branch 'main' into keep_necessary_logits
xyangk Oct 25, 2024
dc7849a
pass test
xyangk Oct 25, 2024
a3984dc
Merge branch 'keep_necessary_logits' of github.com:xyangk/trl into ke…
xyangk Oct 25, 2024
4debb1d
Merge branch 'main' into keep_necessary_logits
xyangk Oct 28, 2024
1ab1bd9
Merge branch 'main' into keep_necessary_logits
xyangk Oct 31, 2024
664c823
Merge branch 'main' into keep_necessary_logits
xyangk Nov 4, 2024
3e61e4c
Merge branch 'main' into keep_necessary_logits
kashif Nov 6, 2024
70838a7
use check_min_version
xyangk Nov 6, 2024
c4b5c4f
format
xyangk Nov 6, 2024
32ed085
test_dpo_trainer_use_num_logits_to_keep passed
xyangk Nov 6, 2024
77c1fd5
Merge branch 'main' into keep_necessary_logits
kashif Nov 7, 2024
f3c044e
add some comments
xyangk Nov 8, 2024
90324f3
Merge branch 'main' into keep_necessary_logits
kashif Nov 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions tests/test_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,6 +1053,87 @@ def test_dpo_loss_js_div_f(self):
)
self.assertTrue(torch.isfinite(losses).cpu().numpy().all())

def test_dpo_trainer_use_num_logits_to_keep(self):
model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(model_id)

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=1,
learning_rate=9e-1,
eval_strategy="steps",
beta=0.1,
use_num_logits_to_keep=True,
rpo_alpha=0.5,
report_to="none",
)

dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

# dpo train lora model with a lora config
trainer = DPOTrainer(
model=model,
ref_model=None,
args=training_args,
tokenizer=tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
)

training_args.use_num_logits_to_keep = False
trainer2 = DPOTrainer(
model=model,
ref_model=None,
args=training_args,
tokenizer=tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
)

# Fake batch
prompt_input_ids = torch.randint(1, 1000, (2, 10))
chosen_input_ids = torch.randint(1, 1000, (2, 5))
rejected_input_ids = torch.randint(1, 1000, (2, 7))
prompt_attention_mask = torch.ones_like(prompt_input_ids)
chosen_attention_mask = torch.ones_like(chosen_input_ids)
rejected_attention_mask = torch.ones_like(rejected_input_ids)

batch = {
"prompt_input_ids": prompt_input_ids.to(model.device),
"chosen_input_ids": chosen_input_ids.to(model.device),
"rejected_input_ids": rejected_input_ids.to(model.device),
"prompt_attention_mask": prompt_attention_mask.to(model.device),
"chosen_attention_mask": chosen_attention_mask.to(model.device),
"rejected_attention_mask": rejected_attention_mask.to(model.device),
}

output = trainer.concatenated_forward(model, batch)
output2 = trainer2.concatenated_forward(model, batch)

np.testing.assert_allclose(output["nll_loss"].item(), output2["nll_loss"].item(), atol=1e-5)
np.testing.assert_allclose(
output["mean_chosen_logits"].item(), output2["mean_chosen_logits"].item(), atol=1e-5
)
np.testing.assert_allclose(
output["mean_rejected_logits"].item(), output2["mean_rejected_logits"].item(), atol=1e-5
)

for i in range(output["chosen_logps"].shape[0]):
np.testing.assert_allclose(
output["chosen_logps"][i].item(), output2["chosen_logps"][i].item(), atol=1e-5
)
np.testing.assert_allclose(
output["rejected_logps"][i].item(), output2["rejected_logps"][i].item(), atol=1e-5
)

trainer.train()


@require_vision
class DPOVisionTrainerTester(unittest.TestCase):
Expand Down
6 changes: 6 additions & 0 deletions trl/trainer/dpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ class DPOConfig(TrainingArguments):
α parameter from the [RPO](https://huggingface.co/papers/2404.19733) paper (v3), which controls the
weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the
DPO loss. The paper recommends `rpo_alpha=1.0`.
use_num_logits_to_keep (`bool`, *optional*, defaults to `False`):
If `True`, only a specified number of logits are computed in the forward pass of CausalLM. This can be useful
for saving memory and speeding up training by not computing the logits for all tokens, especially in scenarios
when working with very long prompts where labels are -ignored (-100).
[Read more](https://huggingface.co/docs/transformers/main/model_doc/llama#transformers.LlamaForCausalLM)
"""

learning_rate: float = 1e-6
Expand Down Expand Up @@ -176,6 +181,7 @@ class DPOConfig(TrainingArguments):
ref_model_mixup_alpha: float = 0.9
ref_model_sync_steps: int = 64
rpo_alpha: Optional[float] = None
use_num_logits_to_keep: bool = False

def __post_init__(self):
if self.max_target_length is not None:
Expand Down
33 changes: 29 additions & 4 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available
from transformers.utils import check_min_version, is_peft_available
from transformers.utils.deprecation import deprecate_kwarg

from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
Expand Down Expand Up @@ -396,6 +396,7 @@ def make_inputs_require_grad(module, input, output):
self.truncation_mode = args.truncation_mode
self.max_completion_length = args.max_completion_length
self.precompute_ref_log_probs = args.precompute_ref_log_probs
self.use_num_logits_to_keep = args.use_num_logits_to_keep

# Since ref_logs are precomputed on the first call to get_train/eval_dataloader
# keep track of first called to avoid computation of future calls
Expand Down Expand Up @@ -529,6 +530,11 @@ def make_inputs_require_grad(module, input, output):
)

self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))

# num_logits_to_keep is supported since transformers v4.45.0
if self.use_num_logits_to_keep:
check_min_version("4.45.0")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I missed that the transformers version was already updated. Since the current PR is already merged, I'll create a new PR to remove the version check. Thank you for cathing this!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


if self.loss_type == "bco_pair":
self.running = RunningMoments(self.accelerator)

Expand Down Expand Up @@ -1087,23 +1093,42 @@ def concatenated_forward(self, model: nn.Module, batch: Dict[str, Union[List, to
# Get the first column idx that is all zeros and remove every column after that
empty_cols = torch.sum(attention_mask, dim=0) == 0
first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else attention_mask.size(1)
input_ids = input_ids[:, : first_empty_col]
attention_mask = attention_mask[:, : first_empty_col]
loss_mask = loss_mask[:, : first_empty_col]
input_ids = input_ids[:, :first_empty_col]
attention_mask = attention_mask[:, :first_empty_col]
loss_mask = loss_mask[:, :first_empty_col]

# Truncate right
if self.args.max_length is not None:
input_ids = input_ids[:, : self.args.max_length]
attention_mask = attention_mask[:, : self.args.max_length]
loss_mask = loss_mask[:, : self.args.max_length]

if self.use_num_logits_to_keep:
# Compute num_logits_to_keep based on loss_mask pattern:
# [[0, 0, 0, x, x, x, x],
# [0, 0, 0, x, x, x, 0]]
# ^ start computing logits from here ([:, -(7-3+1):])
first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()
num_logits_to_keep = loss_mask.shape[1] - first_compute_index
model_kwargs["num_logits_to_keep"] = num_logits_to_keep.item() + 1 # +1 for the first label

outputs = model(input_ids=input_ids, attention_mask=attention_mask, **model_kwargs)

# Offset the logits by one to align with the labels
logits = outputs.logits[:, :-1, :]
labels = input_ids[:, 1:].clone()
loss_mask = loss_mask[:, 1:].bool()

if self.use_num_logits_to_keep:
# Align labels with logits
# logits: -, -, [x2, x3, x4, x5, x6]
# ^ --------- ^ after logits[:, :-1, :]
# labels: [y0, y1, y2, y3, y4, y5, y6]
# ^ --------- ^ with num_logits_to_keep=4, [:, -4:]
# loss_mask: [0, 0, 0, 1, 1, 1, 1]
labels = labels[:, -num_logits_to_keep:]
loss_mask = loss_mask[:, -num_logits_to_keep:]

if logits.shape[:2] != labels.shape[:2]:
# for llava, the returned logits include the image tokens (placed before the text tokens)
seq_len = labels.shape[1]
Expand Down