Ignore parameters from a checkpoint #497
Labels: enhancement
I think I'm leaning towards #2. I would rather the checkpoint saver save everything, and then selectively load, than discard information. I would rather that parameters not being loaded be explicitly listed rather than implicitly ignored.
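As a rough sketch of the preference expressed above (the helper name is hypothetical; Composer's actual API may differ), the checkpoint saver keeps everything, and the loader filters at load time while explicitly reporting what it skipped:

```python
# Sketch only: save everything, load selectively, and explicitly list
# what was not loaded instead of silently ignoring it.
import torch


def load_all_except(model: torch.nn.Module, state_dict: dict, skip: list) -> None:
    filtered = {k: v for k, v in state_dict.items() if k not in skip}
    # strict=False leaves the skipped parameters at their freshly initialized values.
    missing, unexpected = model.load_state_dict(filtered, strict=False)
    print(f"Parameters not loaded from checkpoint: {sorted(missing)}")
```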
ravi-mosaicml added a commit that referenced this issue on Mar 7, 2022:
This PR is the first in a series for cleaning up the checkpoint API. One of the prerequisites is storing the seed on the state. Here, only the rank zero seed is stored on state, since only the rank zero state is persisted in a checkpoint. The trainer uses a distributed reduction to share the seed across states, so the same seed will be restored when resuming from a checkpoint, even if a seed was not originally specified. This PR ignores the `seed` parameter passed into the trainer when resuming from a checkpoint. For the time being, if a new seed is desired, the `seed` attribute must be removed from the checkpoint state dict. #497 will introduce a cleaner API for this (edge) use case.
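The distributed reduction described here might look roughly like the following sketch (the helper name is illustrative, and it assumes the `torch.distributed` process group is already initialized):

```python
# Sketch: every rank adopts rank zero's seed, so a single consistent seed
# can be persisted in (and restored from) the rank zero checkpoint.
import torch
import torch.distributed as dist


def sync_rank_zero_seed(seed: int) -> int:
    tensor = torch.tensor([seed], dtype=torch.int64)  # use a CUDA tensor with NCCL
    dist.broadcast(tensor, src=0)  # after this, all ranks hold rank zero's value
    return int(tensor.item())
```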
ravi-mosaicml added a second commit that referenced this issue on Mar 11, 2022, with the same message as above.
Motivation
Some GLUE tasks require mid-training: tasks with small, noisy datasets are fine-tuned in multiple stages.
For #3, we load the checkpoint obtained after completing #2. This implies that BERT was trained with a task-specific classification head (in code, there is now a `module.model.classifier.weight` and a `module.model.classifier.bias`) that is irrelevant to us, because:

a) The classification weights are from a task that doesn't matter to us.
b) The classification weights are for a 3-way classification task, and we now have a 2-way classification task.

Because of (b), the matrix sizes will not match up, and training will crash. In code, the classifier weight in the checkpoint from #2 is something like `hidden_size x 3`, and our model has a matrix of shape `hidden_size x 2`. This motivates ignoring certain parameters from a checkpoint.
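A tiny reproduction of the crash (shapes are illustrative, using a 768-dimensional hidden size):

```python
import torch.nn as nn

hidden_size = 768
checkpoint_head = nn.Linear(hidden_size, 3)  # 3-way head saved after #2
new_head = nn.Linear(hidden_size, 2)         # 2-way head needed for #3

# Raises RuntimeError: size mismatch for weight: copying a param with
# shape torch.Size([3, 768]) from checkpoint, the shape in current
# model is torch.Size([2, 768]).
new_head.load_state_dict(checkpoint_head.state_dict())
```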
Solutions
I see three proposed solutions here. Before continuing, please decide which one you like best so I don't bias you.
To unblock myself previously, I have implemented #2: it deletes user-specified weights from the dictionary. This consists of a parameter in checkpointing called `ignore_model_keys` that takes a list of the weights to delete from the `state_dict`; see the sketch below.
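The original snippet is not reproduced here; based on the description, the workaround is roughly the following sketch (the checkpoint layout and surrounding load code are assumptions):

```python
# Sketch of the #2 workaround: delete user-specified keys from the
# state dict before handing it to the model.
import fnmatch

import torch


def filter_state_dict(state_dict: dict, ignore_model_keys: list) -> dict:
    for pattern in ignore_model_keys:
        # fnmatch supports globs, so "module.model.classifier.*" drops the whole head.
        for key in fnmatch.filter(list(state_dict), pattern):
            del state_dict[key]
    return state_dict


# Hypothetical usage: drop the task-specific head before loading.
# state_dict = filter_state_dict(torch.load("checkpoint.pt")["model"],
#                                ["module.model.classifier.*"])
# model.load_state_dict(state_dict, strict=False)
```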
We should decide which of the three solutions to adopt. Once we have decided, I'm happy to code up the pull request!