This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Enable AMP with Apex #3866

Merged
12 commits merged on Mar 2, 2020
33 changes: 31 additions & 2 deletions allennlp/training/trainer.py
@@ -7,6 +7,7 @@
import traceback
from typing import Dict, List, Optional, Tuple, Union, Any

from apex import amp
Contributor

I would suggest this:

try:
	from apex import amp
except ImportError:
	amp = None

and then below, check that amp is not None and raise a ConfigurationError if it is and the user has passed an opt level. Does that make sense?

Contributor Author

Done.
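
(For reference, a sketch of the combined guard described in this thread; the ConfigurationError message below is illustrative, not the exact text merged in this PR.)

from allennlp.common.checks import ConfigurationError

try:
    from apex import amp
except ImportError:
    amp = None

# Later, inside Trainer.__init__, if the user asked for Amp but apex is missing:
if opt_level is not None and amp is None:
    raise ConfigurationError(
        "Apex is not installed, but an opt_level was provided. "
        "Install NVIDIA Apex to enable mixed precision training."
    )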

import torch
import torch.distributed as dist
import torch.optim.lr_scheduler
@@ -66,6 +67,7 @@ def __init__(
local_rank: int = 0,
world_size: int = 1,
num_gradient_accumulation_steps: int = 1,
opt_level: Optional[str] = None,
Contributor

Can you add this to the docstring?

Contributor Author

Done, added a comment below.

Contributor

What about amp_level?

Contributor Author
@JohnGiorgi, Feb 28, 2020

Hmm, they also call this arg opt_level in apex (see here). Is there a good argument for calling it something different in AllenNLP?

Contributor
@bryant1410, Feb 28, 2020

Yeah, cause amp is the main thing in apex. As far as I understand, opt_level means "option level". In AllenNLP there are many options. So, amp_level would mean "automatic mixed precision level".

Contributor

Oh, it's actually "optimization level". Still, I'd do amp_opt_level, cause there are many different options in Trainer and it can get confusing. What do you think?

Contributor

I would lean towards sticking with opt_level here, or perhaps amp_opt_level. Switching the name of a parameter makes it harder for someone familiar with apex to know what's going on here.

) -> None:
"""
A trainer for doing supervised learning. It just takes a labeled dataset
@@ -185,6 +187,11 @@ def __init__(
Gradients are accumulated for the given number of steps before doing an optimizer step. This can
be useful to accommodate batches that are larger than the RAM size. Refer Thomas Wolf's
[post](https://tinyurl.com/y5mv44fw) for details on Gradient Accumulation.
opt_level : `str`, optional, (default = `None`)
Each opt_level establishes a set of properties that govern Amp’s implementation of pure or mixed
precision training. Must be a choice of `"O0"`, `"O1"`, `"O2"`, or `"O3"`.
See the Apex [documentation](https://nvidia.github.io/apex/amp.html#opt-levels-and-properties) for
more details. If `None`, Amp is not used. Defaults to `None`.
Comment on lines +193 to +197
Contributor Author

@matt-gardner How's this?
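
(As a rough illustration of how this option surfaces in an experiment config: the surrounding keys below are placeholders for a typical trainer, and only opt_level comes from this PR.)

# Illustrative trainer config, written here as a Python dict.
trainer_config = {
    "num_epochs": 3,
    "cuda_device": 0,
    "optimizer": {"type": "adam", "lr": 1e-3},
    "opt_level": "O1",  # one of "O0" through "O3"; omit or set to None to train without Amp
}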

"""
super().__init__(serialization_dir, cuda_device, distributed, local_rank, world_size)

@@ -267,6 +274,11 @@ def __init__(
if histogram_interval is not None:
self._tensorboard.enable_activation_logging(self.model)

# Enable automatic mixed precision training with NVIDIA Apex.
self._opt_level = opt_level
if self._opt_level is not None:
self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level=self._opt_level)

# Using `DistributedDataParallel`(ddp) brings in a quirk wrt AllenNLP's `Model` interface and its
# usage. A `Model` object is wrapped by `ddp`, but assigning the wrapped model to `self.model`
# will break the usages such as `Model.get_regularization_penalty`, `Model.get_metrics`, etc.
@@ -282,7 +294,18 @@ def __init__(
self._pytorch_model = self.model

def rescale_gradients(self) -> Optional[float]:
return training_util.rescale_gradients(self.model, self._grad_norm)
"""
Performs gradient rescaling. Is a no-op if gradient rescaling is not enabled.
"""
if self._grad_norm:
if self._opt_level is not None:
# See: https://nvidia.github.io/apex/advanced.html#gradient-clipping
parameters_to_clip = [p for p in amp.master_params(self.optimizer) if p.grad is not None]
else:
parameters_to_clip = [p for p in self.model.parameters() if p.grad is not None]
return training_util.sparse_clip_norm(parameters_to_clip, self._grad_norm)
else:
return None

def batch_loss(self, batch: TensorDict, for_training: bool) -> torch.Tensor:
"""
@@ -384,7 +407,11 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
if torch.isnan(loss):
raise ValueError("nan loss encountered")
loss = loss / len(batch_group)
loss.backward()
if self._opt_level is not None:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
train_loss += loss.item()

batch_grad_norm = self.rescale_gradients()
@@ -829,6 +856,7 @@ def from_partial_objects(
distributed: bool = None,
world_size: int = 1,
num_gradient_accumulation_steps: int = 1,
opt_level: Optional[str] = None,
no_grad: List[str] = None,
optimizer: Lazy[Optimizer] = None,
learning_rate_scheduler: Lazy[LearningRateScheduler] = None,
@@ -910,4 +938,5 @@ def from_partial_objects(
local_rank=local_rank,
world_size=world_size,
num_gradient_accumulation_steps=num_gradient_accumulation_steps,
opt_level=opt_level,
)
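
Taken together, the calls added above (amp.initialize, amp.scale_loss, amp.master_params) follow the standard Apex Amp pattern. A minimal standalone sketch, outside AllenNLP, assuming apex is installed and a CUDA device is available:

import torch
from apex import amp

model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Wrap the model and optimizer once, up front (mirrors Trainer.__init__ above).
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

for _ in range(10):
    optimizer.zero_grad()
    loss = model(torch.randn(32, 10).cuda()).pow(2).mean()
    # Scale the loss so fp16 gradients do not underflow (mirrors _train_epoch).
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    # Clip the fp32 master gradients rather than model.parameters() (mirrors rescale_gradients).
    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm=1.0)
    optimizer.step()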
3 changes: 3 additions & 0 deletions dev-requirements.txt
@@ -25,6 +25,9 @@ matplotlib>=2.2.3
# Required to run sanic tests
aiohttp

# Required for automatic mixed precision (AMP) training
git+https://github.com/NVIDIA/apex.git@master

Comment on lines +28 to +30
Contributor Author

Okay, as discussed, I have added apex as a dev requirement.

#### DOC-RELATED PACKAGES ####

# YAML manipulation