
Add a variant of CPO, SimPO #1703

Merged · 12 commits · Jun 6, 2024
3 changes: 3 additions & 0 deletions docs/source/cpo_trainer.mdx
@@ -5,6 +5,9 @@ avoid generating adequate, but not perfect translations in Machine Translation (

CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT’s methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Second, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective.

## A Variant of CPO: SimPO
There is also a variant of CPO, [SimPO: Simple Preference Optimization with a Reference-Free Reward](https://arxiv.org/abs/2405.14734), which adds a target reward margin and does not use BC (behavior cloning) regularization. Set `loss_type="simpo"` in the `CPOConfig` to use this loss.
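
For example, a minimal usage sketch (the `model`, `tokenizer`, and `train_dataset` variables are placeholders, and the hyperparameter values are illustrative rather than tuned recommendations):

```python
from trl import CPOConfig, CPOTrainer

# Illustrative values only; tune beta and simpo_gamma for your task.
training_args = CPOConfig(
    output_dir="./simpo-model",
    loss_type="simpo",  # switch the CPO trainer to the SimPO loss
    simpo_gamma=0.5,    # target reward margin
    beta=2.0,           # temperature on the length-normalized reward
)

trainer = CPOTrainer(
    model=model,                  # a causal LM loaded beforehand (placeholder)
    args=training_args,
    train_dataset=train_dataset,  # prompt/chosen/rejected entries (placeholder)
    tokenizer=tokenizer,          # placeholder
)
trainer.train()
```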


@fe1ixxu thanks for the PR! You implemented SimPO loss elegantly here.

WDYT about changing the documentation to the following:

## SimPO: Regularizing output length

Sometimes, model output length is a confounding factor in post-training evaluations because judges may prefer longer outputs. [SimPO: Simple Preference Optimization with a Reference-Free Reward](https://arxiv.org/abs/2405.14734) is an alternative loss that regularizes output length, adds a reward margin, and does not use BC regularization. Because the reward is length-normalized, the trained model does not over-exploit longer outputs.

We can easily reuse `CPOTrainer` for SimPO by setting `loss_type="simpo"` and tuning `simpo_gamma` in the `CPOConfig`.

This way, we emphasize the potential benefit of SimPO (which is to regularize length).
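
For reference, here is my transcription of the length-normalized objective from the SimPO paper, where $y_w$ and $y_l$ are the chosen and rejected completions and `simpo_gamma` plays the role of the margin $\gamma$:

$$
\mathcal{L}_{\text{SimPO}}(\pi_\theta) = -\log \sigma\!\left( \frac{\beta}{|y_w|} \log \pi_\theta(y_w \mid x) \;-\; \frac{\beta}{|y_l|} \log \pi_\theta(y_l \mid x) \;-\; \gamma \right)
$$

Dividing each log probability by the completion length $|y|$ is what removes the length bias, and $\gamma$ enforces a minimum reward margin between chosen and rejected completions.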

## Expected dataset format

The CPO trainer expects a format identical to the DPO trainer, which should include three entries. These entries should be named as follows:
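
The entry names themselves fall in a collapsed part of this diff; assuming the DPO-style convention (`prompt`, `chosen`, `rejected`), a minimal example would look like:

```python
# Minimal illustrative dataset dict, assuming DPO-style entry names.
cpo_dataset_dict = {
    "prompt": ["Translate to German: Hello, world!"],
    "chosen": ["Hallo, Welt!"],
    "rejected": ["Hallo Welt, du!"],
}
```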
2 changes: 2 additions & 0 deletions tests/test_cpo_trainer.py
@@ -84,6 +84,8 @@ def _init_dummy_dataset(self):
["t5", "hinge"],
["gpt2", "ipo"],
["t5", "ipo"],
["gpt2", "simpo"],
["t5", "simpo"],
]
)
def test_cpo_trainer(self, name, loss_type):
5 changes: 4 additions & 1 deletion trl/trainer/cpo_config.py
@@ -41,6 +41,8 @@ class CPOConfig(TrainingArguments):
The type of loss to use. This argument is required if you want to use the default data collator.
label_pad_token_id (`int`, defaults to `-100`):
The label pad token id. This argument is required if you want to use the default data collator.
simpo_gamma (`float`, defaults to `0.5`):
A target reward margin for the SimPO loss, used only when `loss_type="simpo"`.
padding_value (`int`, defaults to `None`):
The padding value if it is different to the tokenizer's pad_token_id.
truncation_mode (`str`, defaults to `keep_end`):
@@ -64,8 +66,9 @@

beta: float = 0.1
label_smoothing: float = 0
loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair", "simpo"] = "sigmoid"
disable_dropout: bool = True
simpo_gamma: float = 0.5

label_pad_token_id: int = -100
padding_value: int = None
24 changes: 20 additions & 4 deletions trl/trainer/cpo_trainer.py
@@ -268,6 +268,9 @@ def make_inputs_require_grad(module, input, output):
self.label_smoothing = args.label_smoothing
self.loss_type = args.loss_type

if args.loss_type == "simpo":
self.simpo_gamma = args.simpo_gamma

self._stored_metrics = defaultdict(lambda: defaultdict(list))

# Compute that only on the main process for faster data processing.
@@ -585,7 +588,16 @@ def cpo_loss(
# The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5.
# We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
# calculates a conservative CPO loss.
if self.loss_type == "sigmoid":

if self.loss_type == "simpo":
gamma_logratios = self.simpo_gamma / self.beta
logits = logits - gamma_logratios
# Same smoothed-sigmoid loss as the "sigmoid" branch below, applied to the margin-shifted logits.
losses = (
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
elif self.loss_type == "sigmoid":
# This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
losses = (
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
@@ -598,7 +610,7 @@
losses = (logits - 1 / (2 * self.beta)) ** 2
else:
raise ValueError(
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair']"
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair', 'simpo']"
)

chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
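
As a quick standalone sanity check of the `simpo` branch above (the log probabilities below are made up, not from any real model):

```python
import torch
import torch.nn.functional as F

beta, simpo_gamma, label_smoothing = 2.0, 0.5, 0.0

# Hypothetical length-normalized sequence log probs (average_log_prob=True).
policy_chosen_logps = torch.tensor([-0.8])
policy_rejected_logps = torch.tensor([-1.5])

logits = policy_chosen_logps - policy_rejected_logps  # 0.7
logits = logits - simpo_gamma / beta                  # 0.7 - 0.25 = 0.45
losses = (
    -F.logsigmoid(beta * logits) * (1 - label_smoothing)
    - F.logsigmoid(-beta * logits) * label_smoothing
)
print(losses)  # tensor([0.3412]), i.e. -log(sigmoid(0.9))
```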
@@ -691,12 +703,16 @@ def cross_entropy_loss(logits, labels):
return loss

labels = concatenated_batch["concatenated_labels"].clone()
nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

if self.loss_type != "simpo":
nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
else:
nll_loss = torch.tensor(0.0).to(self.accelerator.device)

all_logps = self.get_batch_logps(
all_logits,
concatenated_batch["concatenated_labels"],
average_log_prob=self.loss_type == "ipo",
average_log_prob=self.loss_type in ["ipo", "simpo"],
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
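
The switch to `average_log_prob` for `"simpo"` is what makes the reward length-normalized. A toy illustration (per-token log probs are invented):

```python
import torch

# Invented per-token log probs for a single completion.
token_logps = torch.tensor([-0.5, -0.4, -0.6, -0.5])

summed = token_logps.sum()     # tensor(-2.): magnitude grows with length
averaged = token_logps.mean()  # tensor(-0.5000): length-normalized, used for "ipo"/"simpo"
```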