Floating-point operations logging in trainer #6768
Conversation
Codecov Report

@@            Coverage Diff            @@
##           master    #6768      +/-  ##
=========================================
+ Coverage   78.47%   79.65%   +1.17%
=========================================
  Files         157      157
  Lines       28569    28625      +56
=========================================
+ Hits        22420    22800     +380
+ Misses       6149     5825     -324

Continue to review the full report at Codecov.
Thanks for the PR! Got a few comments on my side.
src/transformers/trainer.py (Outdated)

@@ -690,6 +701,12 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D

                tr_loss += self.training_step(model, inputs)

                try:
Can we make a cleaner test with isinstance(model, nn.DataParallel)?
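For illustration, such a check might look like the sketch below (the helper name and its usage are assumptions, not code from this PR):

import torch.nn as nn


def unwrap_data_parallel(model: nn.Module) -> nn.Module:
    # An explicit isinstance check is clearer than a try/except around
    # attribute access: DataParallel keeps the wrapped model on .module.
    if isinstance(model, nn.DataParallel):
        return model.module
    return model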
src/transformers/trainer.py (Outdated)

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            raise ValueError("Trainer.model appears to not be a PreTrainedModel")

        xm.rendezvous("saving_checkpoint")
        # Storing the number of floating-point operations that went into the model
Those 7 lines are duplicated, maybe put them in a private method to refactor a bit?
agreed, done
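For illustration, the shared lines could move into a private Trainer method roughly like this (the name _store_flos and the exact condition are assumptions, not necessarily the PR's final code):

    def _store_flos(self):
        # Store the number of floating-point operations that went into the
        # model on its config, so the count is saved with every checkpoint
        # and can be reloaded later.
        if self.total_flos is not None and hasattr(self.model, "config"):
            self.model.config.total_flos = self.total_flos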
src/transformers/trainer.py (Outdated)

            concat = concat[:num_total_examples]
        return concat

    def distributed_broadcast_scalars(
This doesn't seem to use self (and neither does distributed_concat), so maybe those two should be functions rather than methods?
Agreed, and I will move them out. Do you think we should keep a redirection so that distributed_concat stays backwards-compatible?
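A sketch of what the module-level version could look like, with assumed names and signatures (not the PR's actual code):

from typing import Optional

import torch
import torch.distributed as dist


def distributed_concat(tensor: torch.Tensor, num_total_examples: Optional[int] = None) -> torch.Tensor:
    # Gather the tensor from every process and concatenate along dim 0.
    output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
    dist.all_gather(output_tensors, tensor)
    concat = torch.cat(output_tensors, dim=0)
    if num_total_examples is not None:
        # Drop the samples added as padding by the distributed sampler.
        concat = concat[:num_total_examples]
    return concat

# A deprecated Trainer.distributed_concat method could then simply emit a
# FutureWarning and delegate to this function to stay backwards-compatible.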
            another model, either implement such a method in the model or override this method.

            Args:
                model (:obj:`nn.Module`):
Can't we use self.model?
Yep, changed it; it also lets us save a few lines in the main method.
We can remove the docstring as well
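The method in question appears to be the Trainer's floating-point-operations helper; a sketch of the simplified shape after the change (assumed names, not the exact final code):

    def floating_point_ops(self, inputs):
        # Reads self.model directly instead of taking a model argument, so
        # the `model` entry in the docstring above can be dropped as well.
        if hasattr(self.model, "floating_point_ops"):
            return self.model.floating_point_ops(inputs)
        return 0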
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
LGTM, left a few comments.
        # in case the model has no config
        combined_dict = {**self.args.to_sanitized_dict()}
Is there an example of a model without a configuration?
Ah yes, it's something @sgugger mentioned as well: when writing for Trainer we assume the model is a PretrainedModel, but the model used in the test doesn't inherit from PretrainedModel, which is why I put this in. @julien-c also liked the idea of Trainer being domain-agnostic (i.e. not only NLP), so I figured I might as well add this line since it isn't expensive. In the end it's something we might want to think about, since there are a lot of references to model.config elsewhere (for example when training on TPU, which the test doesn't cover).
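Concretely, the guard around the snippet above could look something like this (a sketch; the surrounding wandb/comet setup code is assumed):

        # Only merge the model config into the logged hyperparameters if the
        # model actually has one, so plain nn.Module models still work.
        combined_dict = {**self.args.to_sanitized_dict()}
        if hasattr(model, "config") and model.config is not None:
            combined_dict = {**model.config.to_dict(), **combined_dict}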
        self.total_flos = getattr(model.config, "total_flos", 0)
Wouldn't this fail if the model didn't have a config?
Yes, the dummy test model doesn't go through this code path since it doesn't have a method to calculate flos, so I didn't catch it! See above: I think we'll have to decide whether or not we want to assume the model has a config.
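If the config is made optional, the lookup could be hedged along these lines (a sketch, not the PR's final code):

        # Fall back to 0 when the model has no config or no stored FLO count.
        if hasattr(model, "config") and model.config is not None:
            self.total_flos = getattr(model.config, "total_flos", 0)
        else:
            self.total_flos = 0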
I think even with domain-agnostic models we'd like to keep the configuration, no? I'm not sure the trainer would behave correctly without a configuration, so if we want to remove the dependency on configurations, we might as well do it all at once, right? Would the goal be to have the trainer accept all …
As agreed upon internally, we will move to Trainer accepting models that instantiate a base abstract class / conform to some protocol. I think the config will be among the required fields, but I have to work a bit more on this to be sure. In any case, this is work for a subsequent PR :-)
* neFLOs calculation, logging, and reloading (huggingface#1)
* testing distributed consecutive batches
* fixed AttributeError from DataParallel
* removed verbosity
* rotate with use_mtime=True
* removed print
* fixed interaction with gradient accumulation
* indent formatting
* distributed neflo counting
* fixed typo
* fixed typo
* mean distributed losses
* exporting log history
* moved a few functions
* floating_point_ops clarification for transformers with parameter-reuse
* code quality
* double import
* made flo estimation more task-agnostic
* only logging flos if computed
* code quality
* unused import
* Update src/transformers/trainer.py (Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>)
* Update src/transformers/modeling_utils.py (Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>)
* Sylvain review
* Update src/transformers/modeling_utils.py (Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>)
* black

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
First of two PRs to implement #4847.

This directly logs floating-point operations in wandb and comet, and creates a log_history.json file with the training metrics. To do so, it adds methods to PretrainedModel to count parameters with and without embeddings, and to count the number of floating-point operations. It also includes a few Trainer fixes, most importantly averaging the eval loss across processes rather than logging only the loss from process 0, and a fix for a bug with checkpoint folder creation.
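For context, a rough sketch of the kind of per-step estimate such a method can use, based on the common 6 x tokens x parameters rule of thumb (the exact formula, method names, and embedding handling in the PR may differ):

import torch.nn as nn


def estimate_training_flos(model: nn.Module, batch_size: int, sequence_length: int) -> int:
    # Rule of thumb: each parameter contributes roughly 2 FLOs per token in
    # the forward pass and 4 in the backward pass, hence the factor of 6.
    # The PR also distinguishes counts with and without embedding parameters;
    # this sketch simply counts every parameter.
    num_parameters = sum(p.numel() for p in model.parameters())
    return 6 * batch_size * sequence_length * num_parameters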