Floating-point operations logging in trainer #6768

Merged (33 commits into huggingface:master, Sep 8, 2020)

Conversation

TevenLeScao (Contributor):

First of two PRs to implement #4847:

  • logging loss vs floating-point operations
  • using the results for scaling laws analysis

This directly logs floating-point operations in wandb and comet, and creates a log_history.json file with the training metrics. To do so, it adds methods to PreTrainedModel to count parameters with and without embeddings, and to count the number of floating-point operations. It also includes a few Trainer fixes, most importantly averaging the eval loss across processes rather than logging the one from process 0, and a fix for a bug in checkpoint folder creation.
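
As a quick illustration, here is a minimal usage sketch of the new counting helpers. floating_point_ops is the method discussed in this PR; the num_parameters(exclude_embeddings=...) signature is an assumption for illustration and may differ from the merged code.

```python
import torch
from transformers import BertConfig, BertForMaskedLM

# Build a small model; any PreTrainedModel subclass would do.
model = BertForMaskedLM(BertConfig())

# Parameter counts with and without the embedding matrices
# (the exclude_embeddings keyword is assumed here for illustration).
total_params = model.num_parameters()
non_embedding_params = model.num_parameters(exclude_embeddings=True)

# Estimate the floating-point operations for one batch, based on the number
# of tokens in the batch and the non-embedding parameter count.
batch = {"input_ids": torch.randint(0, model.config.vocab_size, (8, 128))}
flos = model.floating_point_ops(batch)

print(total_params, non_embedding_params, flos)
```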

codecov bot commented Aug 28, 2020:

Codecov Report

Merging #6768 into master will increase coverage by 1.17%.
The diff coverage is 38.57%.


@@            Coverage Diff             @@
##           master    #6768      +/-   ##
==========================================
+ Coverage   78.47%   79.65%   +1.17%     
==========================================
  Files         157      157              
  Lines       28569    28625      +56     
==========================================
+ Hits        22420    22800     +380     
+ Misses       6149     5825     -324     
Impacted Files | Coverage Δ
src/transformers/trainer.py | 51.85% <31.48%> (-1.81%) ⬇️
src/transformers/modeling_utils.py | 86.66% <62.50%> (-0.84%) ⬇️
src/transformers/modeling_tf_xlm.py | 88.42% <0.00%> (-4.85%) ⬇️
src/transformers/file_utils.py | 82.41% <0.00%> (-0.26%) ⬇️
src/transformers/modeling_bart.py | 95.56% <0.00%> (+0.17%) ⬆️
src/transformers/configuration_bart.py | 94.00% <0.00%> (+4.00%) ⬆️
src/transformers/tokenization_xlnet.py | 90.09% <0.00%> (+23.42%) ⬆️
src/transformers/modeling_tf_transfo_xl.py | 88.13% <0.00%> (+68.28%) ⬆️
...c/transformers/modeling_tf_transfo_xl_utilities.py | 86.00% <0.00%> (+76.00%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@sgugger (Collaborator) left a comment:

Thanks for the PR! Got a few comments on my side.

@@ -690,6 +701,12 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D

tr_loss += self.training_step(model, inputs)

try:

Collaborator:

Can we make a cleaner test with isinstance(model, nn.DataParallel)?
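
For reference, a minimal sketch of the suggested check (illustrative only, not the merged code; model stands for the module used in the training step):

```python
import torch.nn as nn

# Instead of wrapping the attribute access in a try/except, test explicitly
# whether the model has been wrapped by nn.DataParallel and unwrap it.
if isinstance(model, nn.DataParallel):
    underlying_model = model.module  # DataParallel keeps the wrapped model on .module
else:
    underlying_model = model
```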


# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel):
    raise ValueError("Trainer.model appears to not be a PreTrainedModel")

xm.rendezvous("saving_checkpoint")
# Storing the number of floating-point operations that went into the model

Collaborator:

Those 7 lines are duplicated, maybe put them in a private method to refactor a bit?

Contributor Author:

Agreed, done.
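
A sketch of what such a private helper could look like (the name _store_total_flos and the config-attribute handling are hypothetical, not necessarily the merged code):

```python
def _store_total_flos(self):
    # Store the number of floating-point operations that went into the model
    # on its config, so the count is written out with every checkpoint and can
    # be reloaded when training resumes.
    if self.total_flos is not None and hasattr(self.model, "config"):
        self.model.config.total_flos = self.total_flos
```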

concat = concat[:num_total_examples]
return concat

def distributed_broadcast_scalars(

Collaborator:

This doesn't seem to use self (and neither does distributed_concat), so maybe those two methods should be module-level functions?

Contributor Author:

Agreed, I will move them out. Do you think we should keep a redirection so that distributed_concat stays backwards-compatible?
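
A sketch of what the move plus a backwards-compatible redirection might look like (illustrative; the deprecation warning and exact signatures are assumptions):

```python
import warnings
from typing import Optional

import torch
import torch.distributed as dist


def distributed_concat(tensor: torch.Tensor, num_total_examples: Optional[int] = None) -> torch.Tensor:
    # Module-level version: gather the tensor from every process and
    # concatenate along the batch dimension.
    output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
    dist.all_gather(output_tensors, tensor)
    concat = torch.cat(output_tensors, dim=0)
    # drop the samples the distributed sampler added to even out the batches
    if num_total_examples is not None:
        concat = concat[:num_total_examples]
    return concat


class Trainer:
    def distributed_concat(self, tensor, num_total_examples=None):
        # Backwards-compatible redirection: keep the old method as a thin,
        # deprecated wrapper around the module-level function.
        warnings.warn(
            "Trainer.distributed_concat is deprecated; use the module-level "
            "distributed_concat function instead.",
            FutureWarning,
        )
        return distributed_concat(tensor, num_total_examples)
```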

another model, either implement such a method in the model or override this method.

Args:
model (:obj:`nn.Module`):

Collaborator:

Can't we use self.model?

Contributor Author:

Yep, changed it; that also lets us save a few lines in the main method.

Member:

We can remove the docstring as well

TevenLeScao and others added 6 commits August 31, 2020 16:08

@LysandreJik (Member) left a comment:

LGTM, left a few comments.

Comment on lines +479 to +480
# in case the model has no config
combined_dict = {**self.args.to_sanitized_dict()}

Member:

Is there an example of a model without a configuration?

@TevenLeScao (Contributor Author) commented Sep 3, 2020:

Ah yes, it's something @sgugger mentioned as well: when writing Trainer we assume that the model is a PreTrainedModel, but the model in the test doesn't inherit from PreTrainedModel, which is why I put this in. @julien-c also liked the idea of Trainer being domain-agnostic (i.e. not only NLP), so I figured I might as well add this line since it isn't expensive. In the end it's something we might want to think about, since there are a lot of references to model.config (for example when training on TPU, which the test doesn't cover).
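
For reference, a sketch of the guarded setup being discussed (following the snippet above; not necessarily the exact merged code):

```python
# Start from the sanitized training arguments, which every Trainer has ...
combined_dict = {**self.args.to_sanitized_dict()}
# ... and only merge in the model configuration when the model actually has one.
if hasattr(self.model, "config") and self.model.config is not None:
    combined_dict = {**self.model.config.to_dict(), **combined_dict}
```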

Comment on lines +658 to +659
self.total_flos = getattr(model.config, "total_flos", 0)

Member:

wouldn't this fail if the model didn't have a config?

Contributor Author:

Yes. The dummy test model doesn't go through this path since it doesn't have a method to calculate FLOs, so I didn't catch it! See above; I think we'll have to decide whether or not we want to assume the model has a config.
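
One possible guard if Trainer does not get to assume a config (purely illustrative):

```python
# Fall back to 0 when the model has no config attribute at all.
model_config = getattr(model, "config", None)
self.total_flos = getattr(model_config, "total_flos", 0) if model_config is not None else 0
```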


@LysandreJik (Member):

I think even with domain-agnostic models we'd like to keep the configuration, no? I'm not sure the trainer would behave correctly without a configuration, so if we want to remove the dependency on configurations, we might as well do it all at once, right?

Would the goal be to have the trainer accept all nn.Modules?

@sgugger (Collaborator) commented Sep 8, 2020:

As agreed upon internally, we will move to Trainer accepting models that instantiate a base abstract class / conform to some protocol. I think the config will be among the required fields, but I have to work a bit more on this to be sure.

In any case, this is work for a subsequent PR :-)
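
A rough sketch of the kind of interface that future work could require (entirely hypothetical; the actual abstract class or protocol is left to that subsequent PR):

```python
import abc

import torch.nn as nn


class TrainableModel(nn.Module, abc.ABC):
    """Hypothetical base class: any model the Trainer accepts would subclass this.

    A config attribute could also be made part of the required interface,
    as discussed above.
    """

    @abc.abstractmethod
    def floating_point_ops(self, inputs: dict) -> int:
        """Estimate the floating-point operations used for one batch of inputs."""

    @abc.abstractmethod
    def num_parameters(self, exclude_embeddings: bool = False) -> int:
        """Count the model parameters, optionally excluding the embeddings."""
```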

@LysandreJik merged commit 01d340a into huggingface:master on Sep 8, 2020
Zigur pushed a commit to Zigur/transformers that referenced this pull request Oct 26, 2020
* neFLOs calculation, logging, and reloading (huggingface#1)

* testing distributed consecutive batches

* fixed AttributeError from DataParallel

* removed verbosity

* rotate with use_mtime=True

* removed print

* fixed interaction with gradient accumulation

* indent formatting

* distributed neflo counting

* fixed typo

* fixed typo

* mean distributed losses

* exporting log history

* moved a few functions

* floating_point_ops clarification for transformers with parameter-reuse

* code quality

* double import

* made flo estimation more task-agnostic

* only logging flos if computed

* code quality

* unused import

* Update src/transformers/trainer.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Sylvain review

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* black

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>