Floating-point operations logging in trainer #6768

Merged on Sep 8, 2020 (33 commits)
Commits (33), showing changes from all commits
9ee591e
neFLOs calculation, logging, and reloading (#1)
TevenLeScao Jul 28, 2020
a49f2aa
Merge branch 'master' of https://github.com/huggingface/transformers
TevenLeScao Aug 3, 2020
b50d3e1
testing distributed consecutive batches
TevenLeScao Aug 3, 2020
6818ed2
fixed AttributeError from DataParallel
TevenLeScao Aug 3, 2020
5324678
removed verbosity
TevenLeScao Aug 3, 2020
2636bb8
rotate with use_mtime=True
TevenLeScao Aug 3, 2020
04e471b
removed print
TevenLeScao Aug 4, 2020
f78de89
Merge branch 'master' of https://github.com/huggingface/transformers
TevenLeScao Aug 5, 2020
9e7c05a
fixed interaction with gradient accumulation
TevenLeScao Aug 6, 2020
8def613
indent formatting
TevenLeScao Aug 7, 2020
52635d6
Merged with comet integration PR
TevenLeScao Aug 24, 2020
7b8c0ce
nlp-trainer integration merge
TevenLeScao Aug 24, 2020
245df7c
Merge branch 'master' of https://github.com/huggingface/transformers
TevenLeScao Aug 24, 2020
70f919f
distributed neflo counting
TevenLeScao Aug 26, 2020
349e916
fixed typo
TevenLeScao Aug 26, 2020
9cc578d
fixed typo
TevenLeScao Aug 26, 2020
03fe015
mean distributed losses
TevenLeScao Aug 26, 2020
fa43ae1
exporting log history
TevenLeScao Aug 27, 2020
e7a249f
moved a few functions
TevenLeScao Aug 27, 2020
45f5fcb
floating_point_ops clarification for transformers with parameter-reuse
TevenLeScao Aug 27, 2020
ab49c08
Merged with hyperparam change
TevenLeScao Aug 27, 2020
69d2b1e
code quality
TevenLeScao Aug 27, 2020
d796eef
double import
TevenLeScao Aug 27, 2020
c175142
made flo estimation more task-agnostic
TevenLeScao Aug 28, 2020
1773dd6
only logging flos if computed
TevenLeScao Aug 28, 2020
4610852
code quality
TevenLeScao Aug 28, 2020
fae5254
unused import
TevenLeScao Aug 28, 2020
6f1b48c
Update src/transformers/trainer.py
TevenLeScao Aug 31, 2020
304ebe8
Update src/transformers/modeling_utils.py
TevenLeScao Aug 31, 2020
8ec3ea6
Sylvain review
TevenLeScao Aug 31, 2020
4becfac
Update src/transformers/modeling_utils.py
TevenLeScao Aug 31, 2020
1aaaa19
Merge branch 'master' into master
TevenLeScao Aug 31, 2020
eb9d328
black
TevenLeScao Aug 31, 2020
src/transformers/modeling_utils.py (88 changes: 71 additions & 17 deletions)
@@ -17,8 +17,9 @@
import inspect
import os
import re
import warnings
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import torch
from torch import Tensor, device, dtype, nn
@@ -45,7 +46,6 @@

logger = logging.get_logger(__name__)


try:
from torch.nn import Identity
except ImportError:
@@ -91,20 +91,6 @@ class ModuleUtilsMixin:
A few utilities for :obj:`torch.nn.Modules`, to be used as a mixin.
"""

def num_parameters(self, only_trainable: bool = False) -> int:
"""
Get the number of (optionally, trainable) parameters in the model.

Args:
only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return only the number of trainable parameters

Returns:
:obj:`int`: The number of parameters.
"""
params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
return sum(p.numel() for p in params)

@staticmethod
def _hook_rss_memory_pre_forward(module, *args, **kwargs):
try:
@@ -307,9 +293,77 @@ def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
head_mask = head_mask.to(dtype=self.dtype) # switch to fload if need + fp16 compatibility
head_mask = head_mask.to(dtype=self.dtype) # switch to float if need + fp16 compatibility
return head_mask

def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
"""
Get the number of (optionally, trainable or non-embedding) parameters in the module.

Args:
only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return only the number of trainable parameters

exclude_embeddings (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return only the number of non-embedding parameters

Returns:
:obj:`int`: The number of parameters.
"""

def parameter_filter(x):
return (x.requires_grad or not only_trainable) and not (
isinstance(x, torch.nn.Embedding) and exclude_embeddings
)

params = filter(parameter_filter, self.parameters()) if only_trainable else self.parameters()
return sum(p.numel() for p in params)

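As a standalone illustration (not part of the diff) of what `exclude_embeddings` is meant to exclude, the sketch below counts a toy model's parameters with and without its embedding weights, identifying the embedding parameters at the module level:

```python
from torch import nn

# Toy model: an embedding table plus a tiny classifier head.
toy = nn.Sequential(
    nn.Embedding(num_embeddings=1000, embedding_dim=64),  # 1000 * 64 = 64,000 parameters
    nn.Linear(64, 2),                                      # 64 * 2 + 2 = 130 parameters
)

total = sum(p.numel() for p in toy.parameters())

# Embedding weights are easiest to identify on the module, not on the parameter itself.
embedding_param_ids = {
    id(p) for m in toy.modules() if isinstance(m, nn.Embedding) for p in m.parameters()
}
non_embedding = sum(p.numel() for p in toy.parameters() if id(p) not in embedding_param_ids)

print(total, non_embedding)  # 64130 130
```
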
def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int:
"""
Helper function to estimate the total number of tokens from the model inputs.

Args:
input_dict (:obj:`Dict[str, Union[torch.Tensor, Any]]`): The model inputs.

Returns:
:obj:`int`: The total number of tokens.
"""
token_inputs = [tensor for key, tensor in input_dict.items() if "input" in key]
if token_inputs:
return sum([token_input.numel() for token_input in token_inputs])
else:
warnings.warn(
"Could not estimate the number of tokens of the input, floating-point operations will not be computed"
)
return 0
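A quick standalone illustration (not part of the diff) of the heuristic above: every tensor whose key contains "input" contributes its element count, so padding positions are included while attention_mask and labels are ignored:

```python
import torch

batch = {
    "input_ids": torch.ones(8, 128, dtype=torch.long),       # counted: 8 * 128 = 1024
    "attention_mask": torch.ones(8, 128, dtype=torch.long),  # ignored, key has no "input"
    "labels": torch.zeros(8, dtype=torch.long),              # ignored
}
print(sum(t.numel() for k, t in batch.items() if "input" in k))  # 1024
```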

def floating_point_ops(
self, input_dict: Dict[str, Union[torch.Tensor, Any]], exclude_embeddings: bool = True
) -> int:
"""
Get number of (optionally, non-embeddings) floating-point operations for the forward and backward passes of a
batch with this transformer model. Default approximation neglects the quadratic dependency on the number of
tokens (valid if :obj:`12 * d_model << sequence_length`) as laid out in `this paper <https://arxiv.org/pdf/2001.08361.pdf>`__ section
2.1. Should be overridden for transformers with parameter re-use (e.g. ALBERT or Universal Transformers), or
if doing long-range modeling with very high sequence lengths.

Args:
input_dict (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
The model inputs, used to estimate the number of tokens in the batch.

exclude_embeddings (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to exclude the embedding and softmax operations from the count.

Returns:
:obj:`int`: The number of floating-point operations.
"""

return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
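The return value is the 6 * tokens * parameters rule of thumb from the scaling-laws paper linked above (roughly 2 * N * P for the forward pass plus 4 * N * P for the backward pass). A back-of-the-envelope sketch, with an assumed 110M non-embedding parameter count purely for illustration:

```python
non_embedding_params = 110_000_000     # assumed model size, illustration only
batch_size, sequence_length = 8, 512
tokens = batch_size * sequence_length  # what estimate_tokens() returns for input_ids of this shape

flos = 6 * tokens * non_embedding_params
print(f"{flos:.3e}")  # ~2.703e+12 floating-point operations per forward + backward pass
```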


class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
r"""
src/transformers/trainer.py (106 changes: 85 additions & 21 deletions)
@@ -1,4 +1,5 @@
import inspect
import json
import math
import os
import re
@@ -40,6 +41,8 @@
TrainOutput,
default_compute_objective,
default_hp_space,
distributed_broadcast_scalars,
distributed_concat,
set_seed,
)
from .training_args import TrainingArguments
@@ -144,7 +147,7 @@ def __iter__(self):
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
assert (
len(indices) == self.num_samples
), f"Indices length {len(indices)} and and sample number {self.num_samples} mismatched"
), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"

return iter(indices)

@@ -239,6 +242,7 @@ def __init__(
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
)
self.tb_writer = tb_writer
self.log_history = []
if "prediction_loss_only" in kwargs:
warnings.warn(
"Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.",
@@ -292,6 +296,7 @@ def __init__(

self.global_step = None
self.epoch = None
self.total_flos = None
if self.args.fp16 and _use_native_amp:
self.scaler = torch.cuda.amp.GradScaler()
self.hp_search_backend = None
@@ -468,7 +473,11 @@ def setup_wandb(self):
logger.info(
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
)
combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
try:
combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
except AttributeError:
# in case the model has no config
combined_dict = {**self.args.to_sanitized_dict()}
Comment on lines +479 to +480

Member:
Is there an example of a model without a configuration?

TevenLeScao (Contributor Author), Sep 3, 2020:
Ah yes, it's something @sgugger mentioned as well: when writing Trainer we assume that the model is a PretrainedModel, but the model in the test doesn't inherit from PretrainedModel, which is why I put this in. @julien-c also liked the idea of Trainer being domain-agnostic (e.g. not only NLP), so I figured I might as well put this line in since it isn't expensive. In the end it's something we might want to think about, since there are a lot of references to model.config (for example when training on TPU, which the test doesn't cover).

wandb.init(
project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
)
@@ -638,13 +647,16 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D

self.global_step = 0
self.epoch = 0
self.total_flos = 0
epochs_trained = 0
steps_trained_in_current_epoch = 0
# Check if continuing training from a checkpoint
if model_path is not None:
# set global_step to global_step of last saved checkpoint from model path
try:
self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
self.total_flos = getattr(model.config, "total_flos", 0)

Comment on lines +658 to +659

Member:
Wouldn't this fail if the model didn't have a config?

TevenLeScao (Contributor Author):
Yes, the dummy test model doesn't go through it since it doesn't have a method to calculate flos, so I didn't catch it! See above; I think we might have to decide whether we want to assume it has a config or not.

epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
steps_trained_in_current_epoch = self.global_step % (
len(train_dataloader) // self.args.gradient_accumulation_steps
@@ -653,9 +665,11 @@
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(" Continuing training from epoch %d", epochs_trained)
logger.info(" Continuing training from global step %d", self.global_step)
logger.info(" Continuing training from %d non-embedding floating-point operations", self.total_flos)
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
except ValueError:
self.global_step = 0
self.total_flos = 0
logger.info(" Starting fine-tuning.")

tr_loss = 0.0
@@ -689,6 +703,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
continue

tr_loss += self.training_step(model, inputs)
self.total_flos += self.floating_point_ops(inputs)

if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
@@ -758,7 +773,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
self.save_model(output_dir)

if self.is_world_process_zero():
self._rotate_checkpoints()
self._rotate_checkpoints(use_mtime=True)

if is_torch_tpu_available():
xm.rendezvous("saving_optimizer_states")
@@ -927,6 +942,13 @@ def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:

if self.epoch is not None:
logs["epoch"] = self.epoch
if self.total_flos is not None:
if self.args.local_rank != -1:
total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
else:
total_flos = self.total_flos
if total_flos > 0:
logs["total_flos"] = self.total_flos
if self.global_step is None:
# when logging evaluation metrics without training
self.global_step = 0
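In distributed training each process only sees its own slice of the data, so the per-process flos counters have to be combined before logging; distributed_broadcast_scalars appears to gather one scalar per process for the caller to sum. A functionally similar standalone sketch (assuming torch.distributed is initialized, with a single-process fallback):

```python
import torch
import torch.distributed as dist

def sum_scalar_across_processes(value: float) -> float:
    """Sum a Python scalar over all processes; return it unchanged when not distributed."""
    if not (dist.is_available() and dist.is_initialized()):
        return value
    # Note: with the NCCL backend this tensor would need to be moved to the current GPU.
    tensor = torch.tensor(float(value), dtype=torch.float64)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)  # every rank ends up with the global sum
    return tensor.item()

# e.g. total_flos_global = sum_scalar_across_processes(trainer.total_flos or 0)
```
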
@@ -954,6 +976,8 @@ def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
if experiment is not None:
experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers")
output = {**logs, **{"step": self.global_step}}
if self.is_world_process_zero():
self.log_history.append(output)
if iterator is not None:
iterator.write(output)
else:
@@ -1092,13 +1116,17 @@ def _save_tpu(self, output_dir: Optional[str] = None):
if xm.is_master_ordinal():
os.makedirs(output_dir, exist_ok=True)
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
json.dump(
self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
)

# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel):
raise ValueError("Trainer.model appears to not be a PreTrainedModel")

xm.rendezvous("saving_checkpoint")
self._store_flos()
self.model.save_pretrained(output_dir)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
@@ -1111,12 +1139,26 @@ def _save(self, output_dir: Optional[str] = None):
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel):
raise ValueError("Trainer.model appears to not be a PreTrainedModel")
self._store_flos()
self.model.save_pretrained(output_dir)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)

# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
json.dump(
self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
)

def _store_flos(self):
# Storing the number of floating-point operations that went into the model
if self.total_flos is not None:
if self.args.local_rank != -1:
total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
else:
total_flos = self.total_flos
if total_flos > 0:
self.model.config.total_flos = total_flos

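Taken together, _save()/_save_tpu() and _store_flos() mean a checkpoint directory carries both the full metric history and the cumulative operation count (the latter stored on the model config, which save_pretrained serializes to config.json). A minimal sketch for reading them back, with a hypothetical checkpoint path:

```python
import json
import os

checkpoint_dir = "output/checkpoint-500"  # hypothetical path

with open(os.path.join(checkpoint_dir, "config.json")) as f:
    total_flos = json.load(f).get("total_flos", 0)

with open(os.path.join(checkpoint_dir, "log_history.json")) as f:
    log_history = json.load(f)

print(f"~{total_flos:.3e} non-embedding floating-point operations so far")
print(f"{len(log_history)} logged entries; last: {log_history[-1] if log_history else None}")
```
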
def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
ordering_and_checkpoint_path = []
@@ -1248,13 +1290,11 @@ def prediction_loop(
self._past = None

disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
samples_count = 0
for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
batch_size = inputs[list(inputs.keys())[0]].shape[0]
samples_count += batch_size
if loss is not None:
eval_losses.append(loss * batch_size)
eval_losses.extend([loss] * batch_size)
if logits is not None:
preds = logits if preds is None else torch.cat((preds, logits), dim=0)
if labels is not None:
@@ -1267,9 +1307,9 @@
if self.args.local_rank != -1:
# In distributed mode, concatenate all results from all nodes:
if preds is not None:
preds = self.distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
if label_ids is not None:
label_ids = self.distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
elif is_torch_tpu_available():
# tpu-comment: Get all predictions and labels from all worker shards of eval dataset
if preds is not None:
@@ -1288,7 +1328,14 @@
else:
metrics = {}
if len(eval_losses) > 0:
metrics["eval_loss"] = np.sum(eval_losses) / samples_count
if self.args.local_rank != -1:
metrics["eval_loss"] = (
distributed_broadcast_scalars(eval_losses, num_total_examples=self.num_examples(dataloader))
.mean()
.item()
)
else:
metrics["eval_loss"] = np.mean(eval_losses)

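Replacing the weighted sum with a list extended batch_size times yields the same per-sample average while producing a flat list of scalars that can be gathered across processes and averaged directly. A quick numeric check of the equivalence (standalone, illustrative numbers):

```python
import numpy as np

# Two evaluation batches with mean losses 1.0 (4 samples) and 4.0 (2 samples).
weighted = (4 * 1.0 + 2 * 4.0) / 6           # old form: sum(batch_size * loss) / samples_count
replicated = np.mean([1.0] * 4 + [4.0] * 2)  # new form: extend by batch_size, then take the mean
assert weighted == replicated == 2.0
```
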
# Prefix all keys with eval_
for key in list(metrics.keys()):
@@ -1297,18 +1344,6 @@

return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

def distributed_concat(self, tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor:
assert self.args.local_rank != -1

output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensor)

concat = torch.cat(output_tensors, dim=0)

# truncate the dummy elements added by SequentialDistributedSampler
output = concat[:num_total_examples]
return output

def prediction_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
@@ -1354,3 +1389,32 @@ def prediction_step(
if labels is not None:
labels = labels.detach()
return (loss, logits.detach(), labels)

def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
"""
For models that inherit from :class:`~transformers.PreTrainedModel`, uses the model's own
:obj:`floating_point_ops` method to compute the number of floating-point operations for every backward + forward
pass. If using another model, either implement such a method in the model or subclass and override this method.

Args:
model (:obj:`nn.Module`):
The model to evaluate.
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.

Collaborator:
Can't we use self.model?

TevenLeScao (Contributor Author):
Yep, changed it; that lets us save a few lines in the main method too.

Member:
We can remove the docstring as well.

Returns:
:obj:`int`: The number of floating-point operations.
"""

if isinstance(self.model, torch.nn.DataParallel) or isinstance(
self.model, torch.nn.parallel.DistributedDataParallel
):
model = self.model.module
else:
model = self.model

if hasattr(model, "floating_point_ops"):
return model.floating_point_ops(inputs)

else:
return 0
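When the wrapped model exposes no floating_point_ops method, the trainer falls back to 0 and the metric is simply not logged. A hedged sketch of how one might restore the estimate for a plain nn.Module by subclassing (hypothetical subclass name, reusing the 6 * tokens * parameters heuristic from modeling_utils):

```python
from transformers import Trainer


class FlopCountingTrainer(Trainer):  # hypothetical subclass
    def floating_point_ops(self, inputs):
        # Unwrap (Distributed)DataParallel the same way the base method does.
        model = self.model.module if hasattr(self.model, "module") else self.model
        tokens = sum(v.numel() for k, v in inputs.items() if "input" in k and hasattr(v, "numel"))
        params = sum(p.numel() for p in model.parameters())
        return 6 * tokens * params
```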