Multiple calls to .fit
#161
Support multiple calls to .fit for partial training, as discussed in #138. Specifically, this PR:
* Moves most initialization logic to `Trainer.__init__`, so it is invoked only once
* Adds `num_epochs` and `num_batches` as optional parameters to `Trainer.fit`
* Significantly refactors the training loop to support the above
* Adds test cases (and updates the synthetic dataloader to use a deterministic generator, so that the models produced in the tests are identical)
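To make the intended behavior concrete, here is a minimal sketch of resumable training under stated assumptions. `ToyTrainer`, its `weight` update, and the rule that at most one of `num_batches` and `num_epochs` may be passed are hypothetical illustrations, not the code in this PR. Only the idea of one-time setup in `__init__` and optional `num_batches` / `num_epochs` on `fit` comes from the description above.

```python
from typing import List, Optional


class ToyTrainer:
    """Hypothetical stand-in for a trainer; not the library's actual Trainer."""

    def __init__(self, data: List[float], max_epochs: int) -> None:
        # One-time setup lives here, so repeated fit() calls do not reset it.
        self.data = data
        self.max_epochs = max_epochs
        self.epoch = 0           # number of completed epochs
        self.batch_in_epoch = 0  # position inside the current epoch
        self.weight = 0.0        # toy "model" state

    def fit(self, num_batches: Optional[int] = None,
            num_epochs: Optional[int] = None) -> None:
        """Resume training until the given budget or max_epochs is reached."""
        if num_batches is not None and num_epochs is not None:
            # Assumption of this sketch: the two budgets are mutually exclusive.
            raise ValueError("pass at most one of num_batches and num_epochs")
        stop_epoch = self.max_epochs
        if num_epochs is not None:
            # A partially finished epoch counts as one epoch in this sketch.
            stop_epoch = min(self.max_epochs, self.epoch + num_epochs)
        batches_done = 0
        while self.epoch < stop_epoch:
            while self.batch_in_epoch < len(self.data):
                if num_batches is not None and batches_done >= num_batches:
                    return  # budget spent; progress is kept for the next call
                sample = self.data[self.batch_in_epoch]
                self.weight += 0.1 * (sample - self.weight)
                self.batch_in_epoch += 1
                batches_done += 1
            self.epoch += 1
            self.batch_in_epoch = 0
```

Because the data order is deterministic, a run split across several `fit` calls should end in the same state as a single uninterrupted run, which is the property the deterministic dataloader in the tests is meant to allow checking.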
@@ -565,12 +635,89 @@ def eval_subset_num_batches(self):
     def eval_subset_num_batches(self, eval_subset_num_batches: Optional[int] = None):
         self._eval_subset_num_batches = eval_subset_num_batches

-    def fit(self):
-        """Train and evaluate the model on the provided data."""
+    def fit(self, num_batches: Optional[int] = None, num_epochs: Optional[int] = None):
I like that we're moving some params to `fit`, but I'm conflicted about `num_batches`. It doesn't tell me anything about the amount of information I'm feeding to the network. `epochs` is useful as it represents a full pass through the dataset, and `train_fraction` (training on a subset of the dataset) could also be useful for debugging, learning curves, and active learning situations. I know large NLP models like tokens, so perhaps @moinnadeem has some opinions about this?
`num_batches` and `num_epochs` will be replaced with a `duration` argument once #146 is implemented. That will support tokens and a fraction of the total training duration.
(I'd prefer not to add that logic here as #146 is next on my agenda)
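For reference, the units discussed in this thread can be related with a little arithmetic. This is only a sketch: `batches_per_epoch` and `batches_for_fraction` are hypothetical helpers, not part of this PR or of #146, and they assume a fixed batch size with the last partial batch kept.

```python
import math


def batches_per_epoch(num_samples: int, batch_size: int) -> int:
    """Batches in one full pass, keeping the final partial batch."""
    if num_samples <= 0 or batch_size <= 0:
        raise ValueError("num_samples and batch_size must be positive")
    return math.ceil(num_samples / batch_size)


def batches_for_fraction(num_samples: int, batch_size: int,
                         train_fraction: float) -> int:
    """Batches needed to cover train_fraction of one epoch, rounded up."""
    if not 0.0 < train_fraction <= 1.0:
        raise ValueError("train_fraction must be in (0, 1]")
    return math.ceil(train_fraction * batches_per_epoch(num_samples, batch_size))
```

The same `num_batches` value covers very different amounts of data depending on `num_samples` and `batch_size`, which is the concern raised above about it being hard to interpret on its own.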
Closing this PR as we need to revisit multiple calls to .fit.