
Ignore parameters from a checkpoint #497

Closed
moinnadeem opened this issue Feb 16, 2022 · 1 comment


moinnadeem commented Feb 16, 2022

Motivation

Some GLUE tasks with small, noisy datasets benefit from an intermediate fine-tuning ("mid-training") stage:

  1. First, pre-train BERT on a generic language-modeling corpus (e.g., Wikipedia).
  2. Then, fine-tune BERT on an intermediate, task-relevant corpus that helps with logical inference. For instance, fine-tune BERT on a 3-way classification task with a large, labeled dataset.
  3. Finally, fine-tune BERT on your small, relevant task-specific corpus. For instance, a 2-way classification task.

For step 3, we load the checkpoint obtained after completing step 2.

This means the checkpoint contains a task-specific classification head (in code, a module.model.classifier.weight and module.model.classifier.bias) that is irrelevant to us, because:
a) the classification weights come from a task we don't care about, and
b) the classification weights are for a 3-way classification task, while we now have a 2-way classification task.

Because of (b), the matrix sizes will not match up and training will crash. In code, the classifier weight in the checkpoint has shape 3 x hidden_size, while our model's classifier weight has shape 2 x hidden_size.

This motivates ignoring certain parameters from a checkpoint.
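
As a concrete illustration of the failure, here is a small sketch that loads a 3-way checkpoint into a 2-way model. It uses the Hugging Face `transformers` BERT classes for brevity; the module names (`classifier.weight`, `classifier.bias`) and the flat `state_dict` layout are assumptions for the example, not Composer's exact checkpoint format.

```python
import torch
from transformers import BertForSequenceClassification

# Model fine-tuned on the intermediate 3-way task (step 2).
intermediate = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)
state_dict = intermediate.state_dict()  # classifier.weight has shape [3, hidden_size]

# Model for the final 2-way task (step 3).
final = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# Strict loading crashes because classifier.weight is [3, 768] in the checkpoint
# but [2, 768] in the model (and similarly for classifier.bias).
try:
    final.load_state_dict(state_dict)
except RuntimeError as err:
    print(err)  # size mismatch for classifier.weight / classifier.bias
```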

Solutions

I see three proposed solutions here:

  1. As a part of our checkpointing code, we can decide to ignore certain user-specified parameters when saving the checkpoint.
  2. As a part of our checkpoint loading code, we can ignore certain user-specified parameters when loading the checkpoint.
  3. As a part of checkpoint loading, for any keys with the same name, we can ignore any checkpoint parameters that have different shapes than the model weights, and print a warning.

Before continuing, please decide which one you like best, so that I don't bias you.

To unblock myself, I previously implemented option 2: it deletes user-specified weights from the dictionary. This consists of a parameter in checkpointing called ignore_model_keys that takes a list of the weights to delete from the state_dict. See the code:
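
(The linked code is not reproduced here.) For illustration only, a rough sketch of what such an `ignore_model_keys` filter could look like; the function name, the assumption that the file holds a plain `state_dict`, and the use of `strict=False` are all mine and may differ from the actual implementation:

```python
import torch

def load_with_ignored_keys(model, checkpoint_path, ignore_model_keys=()):
    # Assumes the file holds a plain state_dict; a real Composer checkpoint
    # nests the model weights inside a larger state object.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    for key in ignore_model_keys:
        state_dict.pop(key, None)  # drop the weights we do not want to restore
    # strict=False lets the freshly initialized weights stand in for the dropped keys.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"not restored from checkpoint: {missing}")
    return model

# e.g. skip the 3-way classification head from the intermediate task:
# load_with_ignored_keys(
#     model,
#     "intermediate_task.pt",
#     ignore_model_keys=["module.model.classifier.weight", "module.model.classifier.bias"],
# )
```

Option 3 would be essentially the same loop, except that instead of matching user-specified key names it would compare each checkpoint tensor's shape against the corresponding model parameter and drop (with a warning) any that differ.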

We should decide:

  1. Which solution do we prefer?
  2. What does the API look like for our solution? If it is option 2, are we happy with the API and implementation that I have come up with?

Once we have decided, I'm happy to code up the pull request!

@moinnadeem moinnadeem added the enhancement New (engineering) enhancements, such as features or API changes. label Feb 16, 2022
@ravi-mosaicml
Contributor

I think I'm leaning towards option 2. I would rather the checkpoint saver save everything and then selectively load, rather than discard information, and I would rather that parameters not being loaded be explicitly listed than implicitly ignored.

@ravi-mosaicml ravi-mosaicml modified the milestones: v0.5, v0.4.2 Feb 28, 2022
ravi-mosaicml added a commit that referenced this issue Mar 7, 2022
This PR is the first in a series for cleaning up the checkpoint API. One of the prerequisites is storing the seed on the state.
Here, only the rank zero seed is stored on state, since only the rank zero state is persisted in a checkpoint. The trainer uses a distributed reduction to share the seed across states, so the same seed will be restored when resuming from a checkpoint, even if a seed was not originally specified.

This PR ignores the `seed` parameter passed into the trainer when resuming from a checkpoint. For the time being, if a new seed is desired, the `seed` attribute must be removed from the checkpoint state dict. #497 will introduce a cleaner API for this (edge) use case.
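
A minimal sketch of how the rank-zero seed could be shared across ranks, assuming `torch.distributed` is already initialized with a CPU-capable backend such as gloo; this only illustrates the idea and is not the actual Composer code:

```python
import torch
import torch.distributed as dist

def shared_seed(local_seed: int) -> int:
    """Return rank 0's seed on every rank."""
    seed = torch.tensor([local_seed], dtype=torch.int64)
    dist.broadcast(seed, src=0)  # every rank overwrites its value with rank 0's
    return int(seed.item())
```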
ravi-mosaicml added a commit that referenced this issue Mar 11, 2022