Adding SimPO to TRL #1725
base: main
Conversation
@yumeng5 hi!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@vwxyzjn Hi! Thank you for reaching out. We have carefully considered the integration of CPO and SimPO and have a few concerns:
- Hyperparameter Differences: Although CPO and SimPO share some hyperparameters, the optimal values can vary significantly. For example, beta tends to be much larger in SimPO compared to other policy optimization algorithms.
- Additional Hyperparameter for SimPO: SimPO introduces an extra hyperparameter, the gamma_beta ratio, which plays a crucial role in determining the appropriate target reward margin. Unlike CPO, which implements this as a constant gamma, we find that tuning the gamma_beta ratio in SimPO is more effective for reaching an optimal point.
Our experiments show that the optimal values for these hyperparameters, as well as the learning rate, can differ significantly from those used in DPO, SLiC-HF, or CPO, despite the similarities in their formulations. Therefore, we think it's better to develop a separate Trainer for SimPO to better highlight its unique requirements for users. What do you think?
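For reference, a minimal sketch (not part of this PR) of the SimPO objective being discussed, showing where beta and the gamma_beta ratio enter; the default values here are placeholders rather than recommended settings:

import torch.nn.functional as F

# avg_chosen_logps / avg_rejected_logps: length-normalized policy log-probabilities
# of the chosen and rejected responses; gamma_beta_ratio is the target-reward-margin
# ratio mentioned above (i.e. gamma = gamma_beta_ratio * beta).
def simpo_loss(avg_chosen_logps, avg_rejected_logps, beta=2.0, gamma_beta_ratio=0.5):
    logits = avg_chosen_logps - avg_rejected_logps - gamma_beta_ratio
    return -F.logsigmoid(beta * logits)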
Hi @xiamengzhou, these are excellent reasons, and they make sense. However, I do have reservations about bloating the codebase from a maintenance point of view. Maybe there are two possibilities going forward:
@vwxyzjn Thanks for the response! Your suggestions make sense to me and we will revise the PR accordingly :) Will ping you here once we are done.
Hi @vwxyzjn Thanks again for your suggestions. We have modified the PR by making SimPOTrainer a wrapper of the CPOTrainer. Would you mind taking a look and letting us know if that looks good to you? Thanks!
trl/trainer/simpo_trainer.py
import wandb

class SimPOTrainer(CPOTrainer):
I'd prefer SimPOTrainer to inherit from Trainer. It's more in line with our single-file philosophy.
Ah, I see I'm contradicting #1725 (comment) in a way.
What about this:
def __init__(
    self,
    model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
    args: Optional[SimPOConfig] = None,
    data_collator: Optional[DataCollator] = None,
    train_dataset: Optional[Dataset] = None,
    eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
    tokenizer: Optional[PreTrainedTokenizerBase] = None,
    model_init: Optional[Callable[[], PreTrainedModel]] = None,
    callbacks: Optional[List[TrainerCallback]] = None,
    optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    peft_config: Optional[Dict] = None,
    compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
):
    super().__init__(
        model=model,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        model_init=model_init,
        callbacks=callbacks,
        optimizers=optimizers,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        peft_config=peft_config,
        compute_metrics=compute_metrics,
    )
    self.gamma_beta_ratio = args.gamma_beta_ratio
    self.sft_weight = args.sft_weight
Instead of the full init?
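For illustration, a hypothetical usage of the light subclass sketched above, assuming a SimPOConfig that extends CPOConfig with gamma_beta_ratio and sft_weight; the model and dataset names are placeholders:

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SimPOConfig, SimPOTrainer  # names proposed in this PR

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# placeholder: any dataset with prompt/chosen/rejected columns
train_dataset = load_dataset("your-preference-dataset", split="train")

training_args = SimPOConfig(
    output_dir="simpo-model",
    beta=2.0,
    gamma_beta_ratio=0.5,
    sft_weight=0.0,
)
trainer = SimPOTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()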
Thanks for reviewing our PR!
Calling the init of CPOTrainer will check for some loss_types (e.g. "ipo", "kto_pair") that are not supported in SimPOTrainer. Also, the CPOTrainer's init contains some error messages that print out "CPOTrainer", which could be confusing when the user is running SimPOTrainer, so I wrote a separate (full) init function for SimPOTrainer. Would love to know what you think!
trl/trl/trainer/cpo_trainer.py
Lines 262 to 267 in 9e9dc96
if args.loss_type in ["hinge", "ipo"] and args.label_smoothing > 0:
    warnings.warn(
        "You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter."
    )
if args.loss_type == "kto_pair":
    raise ValueError("Support for kto_pair has been removed in CPOTrainer. Please use KTOTrainer.")
Thank you!
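One possible middle ground (not from the PR itself) would be to keep the light subclass but re-validate the loss type before delegating to CPOTrainer, so users never hit the CPO-specific messages; the supported set below is an assumption:

class SimPOTrainer(CPOTrainer):
    # assumed set of loss types SimPO supports; adjust to the actual trainer
    _supported_loss_types = ("sigmoid", "hinge")

    def __init__(self, model=None, args=None, **kwargs):
        if args is not None and args.loss_type not in self._supported_loss_types:
            raise ValueError(
                f"SimPOTrainer does not support loss_type='{args.loss_type}'. "
                f"Supported values: {self._supported_loss_types}."
            )
        super().__init__(model=model, args=args, **kwargs)
        self.gamma_beta_ratio = args.gamma_beta_ratio
        self.sft_weight = args.sft_weight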
Your comment makes sense. At this point, I see two solutions that both seem good to me:
- Rewrite a complete trainer
- Make a very light subclass (as suggested here: Adding SimPO to TRL #1725 (comment))
I feel we're in the middle of these two solutions here. So I think that even if the error messages are misleading (CPO trainer instead of SimPO), this won't be a problem for the user.
Hi @qgallouedec Thanks again for your help with reviewing our PR! Let me know if any further revision is needed! Thanks,
Hey @yumeng5 I made some new remarks, and there are still the previous ones to address. Don't hesitate to ping me if you need help.
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Hi @qgallouedec Thanks a lot for your comprehensive review! Following your suggestion, I'm updating the SimPO trainer to be a complete one, which also follows the current setup for other trl trainers. I think I have committed all the changes you made previously. Let me know if they look good to you! Thanks,
I noticed that the PR has remained open for a while -- just wanted to check in and see if there's anything I can do to help! Thanks again for your efforts! Best,
Hi @yumeng5 thanks for adding your implementation to TRL! I just have a few quick questions.
Hello,
This PR adds the SimPO implementation to TRL.
The official codebase: https://github.com/princeton-nlp/SimPO
Thank you!
Yu