Adding SimPO to TRL #1725

Open
wants to merge 25 commits into main

Conversation

@yumeng5 yumeng5 commented Jun 11, 2024

Hello,

This PR adds SimPO implementations to TRL.

The official codebase: https://github.com/princeton-nlp/SimPO

Thank you!
Yu

@AIR-hl
Contributor

AIR-hl commented Jun 12, 2024

@yumeng5 Hi! SimPO has already been implemented in CPOTrainer, but that change has not been released yet.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@vwxyzjn
Contributor

vwxyzjn commented Jun 24, 2024

Hi @yumeng5, we have added SimPO to this PR: #1703. What do you think? We are also happy to incorporate some of the documentation you have in this PR if #1703 is a good implementation of SimPO.
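
For concreteness, here is a rough sketch of what using SimPO through the CPO interface from #1703 could look like on the user side, assuming CPOConfig exposes loss_type="simpo", cpo_alpha, and simpo_gamma (the model name, dataset name, and hyperparameter values below are placeholders, not recommendations):

# Rough sketch, not the final API: assumes the CPOConfig from #1703 exposes
# loss_type="simpo", cpo_alpha, and simpo_gamma. Model/dataset names and
# hyperparameter values are placeholders.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import CPOConfig, CPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
# Any preference dataset with string "prompt"/"chosen"/"rejected" columns.
train_dataset = load_dataset("your-org/your-preference-dataset", split="train")

config = CPOConfig(
    output_dir="simpo-model",
    loss_type="simpo",   # switch the CPO loss to the SimPO formulation
    cpo_alpha=0.0,       # drop the SFT/BC term so the objective is pure SimPO
    simpo_gamma=0.5,     # target reward margin
    beta=2.0,            # SimPO typically uses a much larger beta than DPO/CPO
)

trainer = CPOTrainer(
    model=model,
    args=config,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()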

@xiamengzhou

@vwxyzjn Hi! Thank you for reaching out. We have carefully considered the integration of CPO and SimPO and have a few concerns:

Hyperparameter Differences: Although CPO and SimPO share some hyperparameters, the optimal values can vary significantly. For example, beta tends to be much larger in SimPO compared to other policy optimization algorithms.

Additional Hyperparameter for SimPO: SimPO introduces an extra hyperparameter, the gamma_beta ratio, which plays a crucial role in determining the appropriate target reward margin. Unlike CPO, which implements this as a constant gamma, we find that tuning the gamma_beta ratio in SimPO is more effective for reaching an optimal point.

Our experiments show that the optimal values for these hyperparameters, as well as the learning rate, can differ significantly from those used in DPO, SLiC-HF, or CPO, despite the similarities in their formulations. Therefore, we think it's better to develop a separate Trainer for SimPO to better highlight its unique requirements for users. What do you think?
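
For readers unfamiliar with the formulation, here is a compact sketch of the SimPO loss being discussed (variable names and defaults are illustrative, not this PR's exact code): the implicit reward is the length-averaged log-probability scaled by beta, and the target reward margin is expressed through the gamma/beta ratio.

import torch.nn.functional as F


def simpo_loss(chosen_avg_logps, rejected_avg_logps, beta=2.0, gamma_beta_ratio=0.5):
    """Sketch of the SimPO objective (illustrative).

    chosen_avg_logps / rejected_avg_logps are length-averaged log-probabilities
    of the chosen / rejected responses under the policy; no reference model is
    used. The target reward margin is gamma = beta * gamma_beta_ratio.
    """
    pi_logratios = chosen_avg_logps - rejected_avg_logps
    logits = pi_logratios - gamma_beta_ratio
    # -log(sigmoid(beta * (avg_logp_chosen - avg_logp_rejected - gamma/beta)))
    return -F.logsigmoid(beta * logits)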

@vwxyzjn
Contributor

vwxyzjn commented Jun 25, 2024

Hi @xiamengzhou, those are excellent reasons, and they make sense. However, I do have reservations about bloating the codebase from a maintenance point of view.

Maybe there are two possibilities going forward:

  1. Better documentation and examples: for example, you could add a simpo.py to the examples/scripts folder and set the hyperparameters that make sense for SimPO by default, like here (https://github.com/yumeng5/trl/blob/223ce737d651a78a9b54ee1a4472fc4e4eb61760/examples/scripts/simpo.py#L19-L32). Most people start from the examples when playing with the trainers anyway, so you can ensure that people use the right beta and so on. Then, in the documentation, you can explain that SimPO uses different optimal hyperparameters and recommend them to users.

  2. Do a very light wrapper of CPOTrainer. For example, instead of inheriting from TrainingArguments, maybe inherit from CPOConfig and name it SimPOConfig. That way you can override the default beta values. Similarly, you can inherit from CPOTrainer (a rough sketch of this approach is shown below).
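
A rough sketch of what option 2 might look like (class names and default values are illustrative; it assumes CPOConfig/CPOTrainer from trl plus the SimPO-related options from #1703):

# Rough sketch of option 2 (hypothetical names): a thin config/trainer pair
# that only changes defaults, assuming CPOConfig and CPOTrainer from trl and
# the SimPO-related options from #1703 (loss_type="simpo", cpo_alpha,
# simpo_gamma). Default values are illustrative.
from dataclasses import dataclass

from trl import CPOConfig, CPOTrainer


@dataclass
class SimPOConfig(CPOConfig):
    beta: float = 2.0            # SimPO default, much larger than typical DPO/CPO values
    loss_type: str = "simpo"     # use the SimPO loss inside CPOTrainer
    cpo_alpha: float = 0.0       # no SFT/BC regularizer by default
    simpo_gamma: float = 0.5     # target reward margin


class SimPOTrainer(CPOTrainer):
    """Light wrapper: all behaviour is inherited from CPOTrainer; only the
    configuration defaults differ."""

Usage would then be identical to CPOTrainer, with SimPOConfig simply supplying SimPO-appropriate defaults.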

@xiamengzhou

@vwxyzjn Thanks for the response! Your suggestions make sense to me and we will revise the PR accordingly :) Will ping you here once we are done.

@yumeng5
Author

yumeng5 commented Jul 10, 2024

Hi @vwxyzjn

Thanks again for your suggestions. We have modified the PR so that SimPOTrainer is a wrapper of CPOTrainer. Would you mind taking a look and letting us know if it looks good to you?

Thanks!

import wandb


class SimPOTrainer(CPOTrainer):
Member


I'd prefer SimPOTrainer to inherit from Trainer. It's more in line with our single-file philosophy.

Member

Ah, I see I'm contradicting #1725 (comment) in a way.

What about this:

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        args: Optional[SimPOConfig] = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional[Dict] = None,
        compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
    ):
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
            peft_config=peft_config,
            compute_metrics=compute_metrics,
        )
        self.gamma_beta_ratio = args.gamma_beta_ratio
        self.sft_weight = args.sft_weight

Instead of the full init?

Author

Thanks for reviewing our PR!

Calling CPOTrainer's init would run checks for loss_types (e.g. "ipo", "kto_pair") that are not supported in SimPOTrainer. Also, CPOTrainer's init contains some error messages that mention "CPOTrainer", which could be confusing when the user is running SimPOTrainer, so I wrote a separate (full) init function for SimPOTrainer. Would love to know what you think!

if args.loss_type in ["hinge", "ipo"] and args.label_smoothing > 0:
    warnings.warn(
        "You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter."
    )
if args.loss_type == "kto_pair":
    raise ValueError("Support for kto_pair has been removed in CPOTrainer. Please use KTOTrainer.")

Thank you!

Member

@qgallouedec qgallouedec Aug 6, 2024

Your comment makes sense. At this point, I see two solutions that both seem good to me: keep SimPOTrainer as a light wrapper that reuses CPOTrainer's init, or make it a fully standalone trainer with its own init.

I feel we're in the middle of these two solutions here. So I think that even if the error messages are misleading (CPOTrainer instead of SimPOTrainer), this won't be a problem for the user.

@yumeng5
Author

yumeng5 commented Jul 23, 2024

Hi @qgallouedec

Thanks again for your help with reviewing our PR! Let me know if any further revision is needed!

Thanks,
Yu

@qgallouedec
Member

Hey @yumeng5 I made some new remarks, and there are still the previous ones to address. Don't hesitate to ping me if you need help.

yumeng5 and others added 8 commits August 7, 2024 16:07
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
@yumeng5
Author

yumeng5 commented Aug 7, 2024

Hi @qgallouedec

Thanks a lot for your comprehensive review! Following your suggestion, I'm updating the SimPO trainer to be a complete one, which also follows the current setup for other trl trainers. I think I have committed all the changes you made previously. Let me know if they look good to you!

Thanks,
Yu

@qgallouedec qgallouedec self-assigned this Aug 18, 2024
@yumeng5
Author

yumeng5 commented Oct 19, 2024

Hi @qgallouedec @lewtun

I noticed that the PR has remained open for a while -- just wanted to check in and see if there's anything I can do to help! Thanks again for your efforts!

Best,
Yu

@cchenv

cchenv commented Dec 10, 2024

Hi @yumeng5 thanks for adding your implementation to TRL! I just have a few quick questions.

  1. The existing implementation uses CPOTrainer for training SimPO models. However, with this approach I was not able to reproduce the model checkpoints or metrics linked on your Princeton GitHub: https://github.com/princeton-nlp/SimPO/tree/main. Even though I tried to use the same hyperparameters and dataset, the trained models are much worse, e.g., they frequently produce gibberish. I'm wondering if this PR is a better reflection of your original work.
  2. In your original implementation, there is a special preprocessing step that keeps only the model/assistant message in the chosen/rejected fields: https://github.com/princeton-nlp/SimPO/blob/main/scripts/run_simpo.py#L89-L99. I see this is also the case for this PR. So I'm wondering whether users need to provide their own preprocessing, given a dataset where the chosen and rejected fields contain both user (i.e., prompt) and assistant messages (e.g., https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized?row=0)? A rough sketch of what I mean is shown below.
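
(Not the authors' code, just a hedged illustration of that preprocessing step, assuming the chosen/rejected columns are message lists whose first entry is the user prompt, as in the linked dataset; the real script linked above also applies the tokenizer's chat template.)

# Hedged sketch, not the authors' preprocessing: turn a conversational
# preference dataset (chosen/rejected are lists of {"role", "content"} messages
# that include the user prompt) into string prompt/chosen/rejected columns
# where chosen/rejected contain only the assistant reply.
from datasets import load_dataset


def keep_assistant_only(example):
    example["prompt"] = example["chosen"][0]["content"]       # first turn: user prompt
    example["chosen"] = example["chosen"][-1]["content"]      # last turn: assistant reply
    example["rejected"] = example["rejected"][-1]["content"]  # last turn: assistant reply
    return example


dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
dataset = dataset.map(keep_assistant_only)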

