Add a variant of CPO, SimPO #1703

Merged: 12 commits, Jun 6, 2024
docs/source/cpo_trainer.mdx (2 additions, 0 deletions)
@@ -81,6 +81,8 @@
The [RSO](https://arxiv.org/abs/2309.06657) authors propose to use a hinge loss

The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical understanding of the CPO algorithm, identify an issue with overfitting, and propose an alternative loss which can be used via the `loss_type="ipo"` argument to the trainer. Note that the `beta` parameter is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair, so the smaller the `beta` the larger this gap is. As per the paper, the loss is averaged over the log-likelihoods of the completion tokens (unlike CPO, which only sums them).

+ [SimPO](https://arxiv.org/abs/2405.14734) is an alternative loss that adds a reward margin, applies length normalization, and does not use BC regularization. To use this loss, set `loss_type="simpo"` in the `CPOConfig`.
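A minimal sketch of enabling SimPO end to end (the model name, dataset, and output path below are illustrative placeholders, not values from this PR; `beta=2.0` reflects the larger values the SimPO paper favors rather than a TRL default):

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import CPOConfig, CPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Any preference dataset with "prompt"/"chosen"/"rejected" columns works here.
train_dataset = load_dataset("your-org/your-preference-dataset", split="train")

training_args = CPOConfig(
    output_dir="./simpo-model",  # hypothetical output path
    loss_type="simpo",
    simpo_gamma=0.5,             # target reward margin
    beta=2.0,                    # SimPO favors a larger beta than CPO's default 0.1
)

trainer = CPOTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()
```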


## Logging

tests/test_cpo_trainer.py (2 additions, 0 deletions)
@@ -84,6 +84,8 @@ def _init_dummy_dataset(self):
["t5", "hinge"],
["gpt2", "ipo"],
["t5", "ipo"],
["gpt2", "simpo"],
["t5", "simpo"],
]
)
def test_cpo_trainer(self, name, loss_type):
Expand Down
trl/trainer/cpo_config.py (4 additions, 1 deletion)
@@ -41,6 +41,8 @@ class CPOConfig(TrainingArguments):
            The type of loss to use. This argument is required if you want to use the default data collator.
        label_pad_token_id (`int`, defaults to `-100`):
            The label pad token id. This argument is required if you want to use the default data collator.
+       simpo_gamma (`float`, defaults to `0.5`):
+           The target reward margin for the SimPO loss, used only when `loss_type="simpo"`.
        padding_value (`int`, defaults to `None`):
            The padding value if it is different from the tokenizer's pad_token_id.
        truncation_mode (`str`, defaults to `keep_end`):
@@ -64,8 +66,9 @@

    beta: float = 0.1
    label_smoothing: float = 0
-   loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
+   loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair", "simpo"] = "sigmoid"
    disable_dropout: bool = True
+   simpo_gamma: float = 0.5

    label_pad_token_id: int = -100
    padding_value: int = None
trl/trainer/cpo_trainer.py (20 additions, 4 deletions)
@@ -268,6 +268,9 @@ def make_inputs_require_grad(module, input, output):
        self.label_smoothing = args.label_smoothing
        self.loss_type = args.loss_type

+       if args.loss_type == "simpo":
+           self.simpo_gamma = args.simpo_gamma

self._stored_metrics = defaultdict(lambda: defaultdict(list))

# Compute that only on the main process for faster data processing.
@@ -585,7 +588,16 @@ def cpo_loss(
        # The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5.
        # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
        # calculates a conservative CPO loss.
-       if self.loss_type == "sigmoid":
+       if self.loss_type == "simpo":
+           gamma_logratios = self.simpo_gamma / self.beta
+           logits = logits - gamma_logratios
+           # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
+           losses = (
+               -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+               - F.logsigmoid(-self.beta * logits) * self.label_smoothing
+           )
+       elif self.loss_type == "sigmoid":
            # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
            losses = (
                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
@@ -598,7 +610,7 @@
            losses = (logits - 1 / (2 * self.beta)) ** 2
        else:
            raise ValueError(
-               f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair']"
+               f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair', 'simpo']"
            )

        chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
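For readers following the math, here is a self-contained sketch of the objective the `simpo` branch above computes. It assumes the log-probabilities are already length-averaged (see the `average_log_prob` change below) and that `label_smoothing=0`; the function name and tensor values are illustrative:

```python
import torch
import torch.nn.functional as F

def simpo_loss(chosen_logps, rejected_logps, beta=2.0, gamma=0.5):
    # logits is the (length-averaged) log-likelihood ratio of chosen vs rejected.
    logits = chosen_logps - rejected_logps
    # Subtracting gamma/beta before scaling by beta yields the SimPO loss:
    #   -log sigmoid(beta * logratio - gamma)
    return -F.logsigmoid(beta * (logits - gamma / beta))

# Toy example: the chosen completion is slightly more likely than the rejected one.
print(simpo_loss(torch.tensor([-1.0]), torch.tensor([-1.5])))
```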
@@ -691,12 +703,16 @@ def cross_entropy_loss(logits, labels):
            return loss

        labels = concatenated_batch["concatenated_labels"].clone()
-       nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
+       if self.loss_type != "simpo":
+           nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
+       else:
+           nll_loss = torch.tensor(0.0).to(self.accelerator.device)

        all_logps = self.get_batch_logps(
            all_logits,
            concatenated_batch["concatenated_labels"],
-           average_log_prob=self.loss_type == "ipo",
+           average_log_prob=self.loss_type in ["ipo", "simpo"],
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )
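The `average_log_prob` switch is what gives SimPO its length normalization: for `ipo` and `simpo` the per-token log-probs are averaged over the completion tokens rather than summed. A rough illustration of the difference (hypothetical tensors; the real `get_batch_logps` also handles label shifting and the encoder-decoder case):

```python
import torch

per_token_logps = torch.randn(2, 8)  # (batch, seq_len) token log-probs
loss_mask = torch.ones(2, 8)         # 1 for completion tokens, 0 for padding

summed = (per_token_logps * loss_mask).sum(-1)                        # sigmoid / hinge / kto_pair
averaged = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)  # ipo / simpo
```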