diff --git a/allennlp/tests/training/trainer_test.py b/allennlp/tests/training/trainer_test.py
index 655f0e81ef1..8cd31627435 100644
--- a/allennlp/tests/training/trainer_test.py
+++ b/allennlp/tests/training/trainer_test.py
@@ -7,6 +7,11 @@
 import math
 
 import pytest
+
+try:
+    from apex import amp
+except ImportError:
+    amp = None
 import torch
 from torch.utils.data import DataLoader
 
@@ -121,6 +126,21 @@ def test_passing_trainer_multiple_gpus_raises_error(self):
                 self.model, self.optimizer, self.data_loader, num_epochs=2, cuda_device=[0, 1],
             )
 
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device registered.")
+    @pytest.mark.skipif(amp is None, reason="Apex is not installed.")
+    def test_trainer_can_run_amp(self):
+
+        self.model.cuda()
+        trainer = Trainer(
+            self.model,
+            self.optimizer,
+            self.data_loader,
+            num_epochs=2,
+            cuda_device=0,
+            opt_level="O1",
+        )
+        _ = trainer.train()
+
     def test_trainer_can_resume_training(self):
         trainer = Trainer(
             self.model,
diff --git a/allennlp/training/trainer.py b/allennlp/training/trainer.py
index eb231d0b935..239392db416 100644
--- a/allennlp/training/trainer.py
+++ b/allennlp/training/trainer.py
@@ -7,6 +7,10 @@
 import traceback
 from typing import Dict, List, Optional, Tuple, Union, Any
 
+try:
+    from apex import amp
+except ImportError:
+    amp = None
 import torch
 import torch.distributed as dist
 import torch.optim.lr_scheduler
@@ -66,6 +70,7 @@ def __init__(
         local_rank: int = 0,
         world_size: int = 1,
         num_gradient_accumulation_steps: int = 1,
+        opt_level: Optional[str] = None,
     ) -> None:
         """
         A trainer for doing supervised learning. It just takes a labeled dataset
@@ -185,6 +190,11 @@ def __init__(
            Gradients are accumulated for the given number of steps before doing an optimizer step. This can
            be useful to accommodate batches that are larger than the RAM size. Refer Thomas Wolf's
            [post](https://tinyurl.com/y5mv44fw) for details on Gradient Accumulation.
+        opt_level : `str`, optional, (default = `None`)
+            Each opt_level establishes a set of properties that govern Amp's implementation of pure or mixed
+            precision training. Must be a choice of `"O0"`, `"O1"`, `"O2"`, or `"O3"`.
+            See the Apex [documentation](https://nvidia.github.io/apex/amp.html#opt-levels-and-properties) for
+            more details. If `None`, Amp is not used. Defaults to `None`.
         """
         super().__init__(serialization_dir, cuda_device, distributed, local_rank, world_size)
 
@@ -267,6 +277,21 @@ def __init__(
         if histogram_interval is not None:
             self._tensorboard.enable_activation_logging(self.model)
 
+        # Enable automatic mixed precision training with NVIDIA Apex.
+        self._opt_level = opt_level
+        if self._opt_level is not None:
+            if amp is None:
+                raise ConfigurationError(
+                    (
+                        "Apex not installed but opt_level was provided. Please install NVIDIA's Apex to enable"
+                        " automatic mixed precision (AMP) training. See: https://github.com/NVIDIA/apex."
+                    )
+                )
+
+            self.model, self.optimizer = amp.initialize(
+                self.model, self.optimizer, opt_level=self._opt_level
+            )
+
         # Using `DistributedDataParallel`(ddp) brings in a quirk wrt AllenNLP's `Model` interface and its
         # usage. A `Model` object is wrapped by `ddp`, but assigning the wrapped model to `self.model`
         # will break the usages such as `Model.get_regularization_penalty`, `Model.get_metrics`, etc.
@@ -282,7 +307,20 @@ def __init__(
            self._pytorch_model = self.model
 
     def rescale_gradients(self) -> Optional[float]:
-        return training_util.rescale_gradients(self.model, self._grad_norm)
+        """
+        Performs gradient rescaling. Is a no-op if gradient rescaling is not enabled.
+        """
+        if self._grad_norm:
+            if self._opt_level is not None:
+                # See: https://nvidia.github.io/apex/advanced.html#gradient-clipping
+                parameters_to_clip = [
+                    p for p in amp.master_params(self.optimizer) if p.grad is not None
+                ]
+            else:
+                parameters_to_clip = [p for p in self.model.parameters() if p.grad is not None]
+            return training_util.sparse_clip_norm(parameters_to_clip, self._grad_norm)
+        else:
+            return None
 
     def batch_loss(self, batch: TensorDict, for_training: bool) -> torch.Tensor:
         """
@@ -384,7 +422,11 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
                if torch.isnan(loss):
                    raise ValueError("nan loss encountered")
                loss = loss / len(batch_group)
-                loss.backward()
+                if self._opt_level is not None:
+                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+                        scaled_loss.backward()
+                else:
+                    loss.backward()
                train_loss += loss.item()
 
            batch_grad_norm = self.rescale_gradients()
@@ -829,6 +871,7 @@ def from_partial_objects(
        distributed: bool = None,
        world_size: int = 1,
        num_gradient_accumulation_steps: int = 1,
+        opt_level: Optional[str] = None,
        no_grad: List[str] = None,
        optimizer: Lazy[Optimizer] = None,
        learning_rate_scheduler: Lazy[LearningRateScheduler] = None,
@@ -910,4 +953,5 @@ def from_partial_objects(
            local_rank=local_rank,
            world_size=world_size,
            num_gradient_accumulation_steps=num_gradient_accumulation_steps,
+            opt_level=opt_level,
        )
diff --git a/dev-requirements.txt b/dev-requirements.txt
index 22983385e6b..cc444fb118f 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -25,6 +25,9 @@ matplotlib>=2.2.3
 # Required to run sanic tests
 aiohttp
 
+# Required for automatic mixed precision (AMP) training
+git+https://github.com/NVIDIA/apex.git@master
+
 #### DOC-RELATED PACKAGES ####
 
 # YAML manipulation
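
For context, here is a minimal standalone sketch of the Apex AMP pattern that the changes above wire into the Trainer (amp.initialize, amp.scale_loss, and gradient clipping over amp.master_params). It is illustrative only: it assumes Apex and a CUDA device are available, and the toy linear model, random data, and hyperparameters are placeholders rather than anything taken from AllenNLP.

# Standalone illustration (not part of the diff): the core Apex AMP calls used above.
# Assumes `apex` is installed and a CUDA device is present; model and data are toy placeholders.
import torch
from apex import amp

model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Mirrors the amp.initialize(...) call added to Trainer.__init__ ("O1" is mixed precision).
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

for _ in range(5):
    optimizer.zero_grad()
    inputs = torch.randn(32, 10, device="cuda")
    targets = torch.randn(32, 1, device="cuda")
    loss = torch.nn.functional.mse_loss(model(inputs), targets)

    # Mirrors the amp.scale_loss(...) block added to Trainer._train_epoch.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()

    # Mirrors rescale_gradients(): clip the fp32 master gradients, not the model's fp16 ones.
    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm=1.0)
    optimizer.step()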