Multiple calls to .fit #161

Conversation

ravi-mosaicml
Contributor

Support multiple calls to `.fit` for partial training, as discussed in #138. Specifically, this PR:

* Moves most initialization logic to `Trainer.__init__`, so it is invoked only once
* Adds `num_epochs` and `num_batches` as optional parameters to `Trainer.fit`
* Significantly refactors the training loop to support the above
* Adds test cases (and updates the synthetic dataloader to use a deterministic generator, so the model built in each test is identical)
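To illustrate the intent of the change, here is a minimal sketch of partial training via repeated `.fit` calls. `ToyTrainer` is a hypothetical stand-in, not the real `composer` Trainer; it only shows how one-time setup in `__init__` plus an optional `num_epochs` argument lets training resume where it left off.

```python
# Hypothetical sketch -- NOT composer's actual API.
from typing import Optional


class ToyTrainer:
    def __init__(self, total_epochs: int):
        # One-time initialization, mirroring the PR's move of setup
        # logic into Trainer.__init__.
        self.total_epochs = total_epochs
        self.epochs_trained = 0

    def fit(self, num_epochs: Optional[int] = None):
        # Default: train for whatever remains of the schedule.
        if num_epochs is None:
            num_epochs = self.total_epochs - self.epochs_trained
        target = min(self.epochs_trained + num_epochs, self.total_epochs)
        while self.epochs_trained < target:
            # Stand-in for one epoch of actual training.
            self.epochs_trained += 1


trainer = ToyTrainer(total_epochs=10)
trainer.fit(num_epochs=3)  # partial training
trainer.fit(num_epochs=3)  # resumes where the previous call stopped
trainer.fit()              # finishes the remaining epochs
print(trainer.epochs_trained)  # 10
```

Because state lives on the trainer rather than inside `fit`, each call picks up exactly where the previous one ended, which is the behavior the test cases in this PR exercise.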
```diff
@@ -565,12 +635,89 @@ def eval_subset_num_batches(self):
     def eval_subset_num_batches(self, eval_subset_num_batches: Optional[int] = None):
         self._eval_subset_num_batches = eval_subset_num_batches

-    def fit(self):
+    def fit(self, num_batches: Optional[int] = None, num_epochs: Optional[int] = None):
         """Train and evaluate the model on the provided data."""
```
Contributor

I like that we're moving some params to `fit`, but I'm conflicted about `num_batches`. It doesn't tell me anything about the amount of information I'm feeding to the network. An epoch is useful because it represents a full pass through the dataset, and `train_fraction` (training on a subset of the dataset) could also be useful for debugging, learning curves, and active learning situations. I know large NLP models like tokens, so perhaps @moinnadeem has some opinions about this?

Contributor Author

`num_batches` and `num_epochs` will be replaced with a `duration` argument once #146 is implemented. That will support tokens and a fraction of the total training duration.

Contributor Author

(I'd prefer not to add that logic here as #146 is next on my agenda)

@ravi-mosaicml
Contributor Author

Closing this PR as we need to revisit multiple calls to `.fit`.

@ajaysaini725 ajaysaini725 mentioned this pull request Jan 22, 2022
@Averylamp Averylamp deleted the ravi/trainer_multiple_fit branch March 10, 2022 00:00