Enable AMP with Apex #3866
Changes from 5 commits
@@ -7,6 +7,7 @@
 import traceback
 from typing import Dict, List, Optional, Tuple, Union, Any
 
+from apex import amp
 import torch
 import torch.distributed as dist
 import torch.optim.lr_scheduler
@@ -66,6 +67,7 @@ def __init__(
         local_rank: int = 0,
         world_size: int = 1,
         num_gradient_accumulation_steps: int = 1,
+        opt_level: Optional[str] = None,
     ) -> None:
         """
         A trainer for doing supervised learning. It just takes a labeled dataset

Review thread on the new `opt_level` argument:

- Can you add this to the docstring?
- Done, added a comment below.
- What about
- Hmm, they also call this arg
- Yeah, cause amp is the main thing in apex. As far as I understand,
- Oh, it's actually "optimization level". Still, I'd do
- I would lean towards sticking with
@@ -185,6 +187,11 @@ def __init__(
             Gradients are accumulated for the given number of steps before doing an optimizer step. This can
             be useful to accommodate batches that are larger than the RAM size. Refer to Thomas Wolf's
             [post](https://tinyurl.com/y5mv44fw) for details on Gradient Accumulation.
+        opt_level : `str`, optional (default = `None`)
+            Each `opt_level` establishes a set of properties that govern Amp's implementation of pure or mixed
+            precision training. Must be one of `"O0"`, `"O1"`, `"O2"`, or `"O3"`.
+            See the Apex [documentation](https://nvidia.github.io/apex/amp.html#opt-levels-and-properties) for
+            more details. If `None`, Amp is not used.
         """
         super().__init__(serialization_dir, cuda_device, distributed, local_rank, world_size)
 

Comment on lines +193 to +197: @matt-gardner How's this?
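As a quick illustration of what the documented opt levels do, here is a minimal sketch that hands an `opt_level` straight to Apex. This is not part of the diff: the toy model and optimizer are assumptions, and running it requires Apex and a CUDA device.

```python
import torch
from apex import amp

# Toy model and optimizer, purely for illustration.
model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# "O0" keeps everything in fp32, "O1"/"O2" run mixed precision, "O3" is pure fp16.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
```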
@@ -267,6 +274,11 @@ def __init__(
         if histogram_interval is not None:
             self._tensorboard.enable_activation_logging(self.model)
 
+        # Enable automatic mixed precision training with NVIDIA Apex.
+        self._opt_level = opt_level
+        if self._opt_level is not None:
+            self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level=self._opt_level)
+
         # Using `DistributedDataParallel`(ddp) brings in a quirk wrt AllenNLP's `Model` interface and its
         # usage. A `Model` object is wrapped by `ddp`, but assigning the wrapped model to `self.model`
         # will break the usages such as `Model.get_regularization_penalty`, `Model.get_metrics`, etc.
@@ -282,7 +294,18 @@ def __init__(
             self._pytorch_model = self.model
 
     def rescale_gradients(self) -> Optional[float]:
-        return training_util.rescale_gradients(self.model, self._grad_norm)
+        """
+        Performs gradient rescaling. Is a no-op if gradient rescaling is not enabled.
+        """
+        if self._grad_norm:
+            if self._opt_level is not None:
+                # See: https://nvidia.github.io/apex/advanced.html#gradient-clipping
+                parameters_to_clip = [p for p in amp.master_params(self.optimizer) if p.grad is not None]
+            else:
+                parameters_to_clip = [p for p in self.model.parameters() if p.grad is not None]
+            return training_util.sparse_clip_norm(parameters_to_clip, self._grad_norm)
+        else:
+            return None
 
     def batch_loss(self, batch: TensorDict, for_training: bool) -> torch.Tensor:
         """
@@ -384,7 +407,11 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
                 if torch.isnan(loss):
                     raise ValueError("nan loss encountered")
                 loss = loss / len(batch_group)
-                loss.backward()
+                if self._opt_level is not None:
+                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+                        scaled_loss.backward()
+                else:
+                    loss.backward()
                 train_loss += loss.item()
 
             batch_grad_norm = self.rescale_gradients()
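Taken together, the trainer changes above follow the usual Apex recipe: initialize once, scale the loss before `backward()`, and clip the master gradients. A self-contained sketch of that recipe outside AllenNLP is below; the model, data, and `max_grad_norm` value are illustrative assumptions, and plain `torch.nn.utils.clip_grad_norm_` stands in for the trainer's `sparse_clip_norm`.

```python
import torch
from apex import amp

# Illustrative model/optimizer; requires Apex and a CUDA device.
model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

loss = model(torch.randn(4, 10).cuda()).sum()

# Scale the loss before backward so small fp16 gradients do not underflow.
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()

# Clip the fp32 master gradients rather than the (possibly fp16) model gradients.
max_grad_norm = 1.0  # illustrative value
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)

optimizer.step()
```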
@@ -829,6 +856,7 @@ def from_partial_objects(
         distributed: bool = None,
         world_size: int = 1,
         num_gradient_accumulation_steps: int = 1,
+        opt_level: Optional[str] = None,
         no_grad: List[str] = None,
         optimizer: Lazy[Optimizer] = None,
         learning_rate_scheduler: Lazy[LearningRateScheduler] = None,
@@ -910,4 +938,5 @@ def from_partial_objects(
             local_rank=local_rank,
             world_size=world_size,
             num_gradient_accumulation_steps=num_gradient_accumulation_steps,
+            opt_level=opt_level,
         )
@@ -25,6 +25,9 @@ matplotlib>=2.2.3
 # Required to run sanic tests
 aiohttp
 
+# Required for automatic mixed precision (AMP) training
+git+https://github.com/NVIDIA/apex.git@master
+
 #### DOC-RELATED PACKAGES ####
 
 # YAML manipulation

Comment on lines +28 to +30: Okay, as discussed, I have added apex as a dev requirement.
Follow-up review thread:

- I would suggest this:

  and then below, check that `amp` is not `None` and raise a `ConfigurationError` if it is and the user has passed an opt level. Does that make sense?

- Done.
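The reviewer's actual code suggestion is elided above; one possible reading of it is a guarded import plus a configuration check, sketched below. The helper name `check_for_amp` and the error message are hypothetical, and `ConfigurationError` is assumed to be `allennlp.common.checks.ConfigurationError`.

```python
from typing import Optional

from allennlp.common.checks import ConfigurationError

# Hypothetical sketch: import Apex lazily so the trainer still works without it installed.
try:
    from apex import amp
except ImportError:
    amp = None


def check_for_amp(opt_level: Optional[str]) -> None:
    # Hypothetical helper: raise if the user requested AMP but Apex is unavailable.
    if opt_level is not None and amp is None:
        raise ConfigurationError(
            f"Apex is not installed, but opt_level '{opt_level}' was given. "
            "Install NVIDIA Apex to use automatic mixed precision."
        )
```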