This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Enable AMP with Apex #3866

Merged
12 commits merged into from
Mar 2, 2020

@JohnGiorgi JohnGiorgi commented Feb 27, 2020

Overview

This PR enables automatic mixed precision (AMP) training with NVIDIA's Apex, which has been much discussed in #2149 and is one of the action items on the roadmap.

The modifications were very straightforward. Here is the high-level summary:

  1. Add argument opt_level to Trainer. This coincides with the opt levels of Apex.
  2. In the constructor of Trainer, if opt_level is not None, wrap self.model and self.optimizer with amp.initialize.
  3. rescale_gradients needs to be modified slightly when using Apex. See here for more info.
  4. In _train_epoch, if opt_level is not None, call loss.backward() in the amp.scale_loss context manager.

I got all of this straight from the Apex docs.
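
Steps 2 and 4 above can be sketched roughly as follows. This is a minimal sketch under stated assumptions, not the code in this PR: `maybe_initialize_amp` and `backward_pass` are hypothetical helper names, and the `amp` module is passed in as an argument so the snippet does not import Apex itself. The only Apex calls it relies on are `amp.initialize` and `amp.scale_loss`.

```python
from typing import Any, Optional, Tuple


def maybe_initialize_amp(
    model: Any, optimizer: Any, opt_level: Optional[str], amp: Any
) -> Tuple[Any, Any]:
    """Step 2: wrap model and optimizer only when an opt_level is given."""
    if opt_level is None:
        # Amp disabled: hand back the objects untouched.
        return model, optimizer
    return amp.initialize(model, optimizer, opt_level=opt_level)


def backward_pass(loss: Any, optimizer: Any, opt_level: Optional[str], amp: Any) -> None:
    """Step 4: run backward() on the scaled loss when Amp is enabled."""
    if opt_level is None:
        loss.backward()
        return
    # amp.scale_loss is a context manager that yields the scaled loss.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
```

With Apex installed, the `amp` argument would be the module obtained from `from apex import amp`; with `opt_level=None` both helpers fall back to the ordinary code path.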

Benchmark

I tested it on my model and found a ~40% speedup in time per epoch during training. It also allowed me to train with a 40% larger mini-batch. The majority of my model's parameters exist in a pre-trained transformer (distilroberta-base from Transformers).

Simple benchmark on a 16GB-V100 using a random 1K documents from wikitext-103:

| opt_level | Batch size | Time per epoch |
|-----------|------------|----------------|
| "O0"      | 12*        | ~58s           |
| "O1"      | 12         | ~35s           |
| "O1"      | 20*        | ~39s           |

*Maximum batch size that fits in memory at this opt_level.

Obviously, mileage will vary but this gives a flavour of the magnitude of speedup you can get with Apex.
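
For reference, the relative changes implied by the benchmark table can be worked out with a few lines of arithmetic. This is only a sketch over the three approximate rows reported above; the variable names and the `percent_reduction` helper are illustrative and not part of the PR.

```python
# Figures copied from the benchmark table above (approximate seconds per epoch).
baseline_time = 58.0        # "O0", batch size 12
amp_time_same_batch = 35.0  # "O1", batch size 12
amp_time_max_batch = 39.0   # "O1", batch size 20


def percent_reduction(before: float, after: float) -> float:
    """Relative reduction from `before` to `after`, in percent."""
    return 100.0 * (before - after) / before


time_saved_same_batch = percent_reduction(baseline_time, amp_time_same_batch)
time_saved_max_batch = percent_reduction(baseline_time, amp_time_max_batch)
batch_growth = 100.0 * (20 - 12) / 12

print(f"{time_saved_same_batch:.1f}% less time per epoch at batch size 12")
print(f"{time_saved_max_batch:.1f}% less time per epoch at batch size 20")
print(f"{batch_growth:.1f}% larger maximum batch size")
```

Because the timings in the table are rounded, the resulting percentages are approximate as well.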

Things I need help with

  • The way I have written this, there will be an ImportError if the user does not have amp installed. This is probably not what we want.
  • I just shoehorned in the required modification to rescale_gradients. There is likely a cleaner way to do this.

@@ -66,6 +68,7 @@ def __init__(
local_rank: int = 0,
world_size: int = 1,
num_gradient_accumulation_steps: int = 1,
opt_level: Optional[str] = None,

Contributor

Can you add this to the docstring?

Contributor Author

Done, added a comment below.

Contributor

What about amp_level?

Contributor Author

@JohnGiorgi JohnGiorgi Feb 28, 2020

Hmm, they also call this arg opt_level in apex (see here). Is there a good argument for calling it something different in AllenNLP?

Contributor

@bryant1410 bryant1410 Feb 28, 2020

Yeah, cause amp is the main thing in apex. As far as I understand, opt_level means "option level". In AllenNLP there are many options. So, amp_level would mean "automatic mixed precision level".

Contributor

Oh, it's actually "optimization level". Still, I'd do amp_opt_level, cause there are many different options in Trainer and it can get confusing. What do you think?

Contributor

I would lean towards sticking with opt_level here, or perhaps amp_opt_level. Switching the name of a parameter makes it harder for someone familiar with apex to know what's going on here.

@@ -32,6 +32,8 @@
from allennlp.training.tensorboard_writer import TensorboardWriter
from allennlp.training.trainer_base import TrainerBase

from apex import amp

Contributor

I'm ok with just adding this to our list of dependencies, unless it's super heavy. Though it should be above, in the same block as the torch imports.

Also, can you add a test with a not-None opt_level, showing that this works?

Contributor Author

@JohnGiorgi JohnGiorgi Feb 27, 2020

Okay, moved the apex import to the torch block. Put it at the top because it's convention for them to be alphabetically sorted (I think?)

Here is the install process for apex. I will leave it up to you as to whether this is too heavy or not!

Sure, will work on adding a test either tonight or tomorrow.

Contributor

I guess we can add it to our setup.py, and if someone wants to install their own optimized version, that's fine too? The point is we want to not crash when someone just does pip install allennlp.

Contributor Author

Right, that seems perfectly reasonable to me.

Contributor Author

Also, I need a little time to get my head around how the testing works in AllenNLP. If it is very easy for one of the AllenNLP devs to add a test for when opt_level is not None I could also pass the buck so this can get merged sooner.

return training_util.sparse_clip_norm(parameters_to_clip, self._grad_norm)
return None
else:
return training_util.rescale_gradients(self.model, self._grad_norm)

Contributor

Another option here is to add a use_amp flag to rescale_gradients. Oh, except that then also requires passing in the optimizer... Yeah, I think what you have here is good enough.

Contributor Author

@JohnGiorgi JohnGiorgi Feb 27, 2020

Yes exactly my thinking. Otherwise, we could have a second util function named something like rescale_gradients_amp which accepts an optimizer.

Contributor

I think adding that function would be overkill. If we want to refactor things, we should have rescale_gradients take a list of parameters. But then we can just remove it and call sparse_clip_norm directly. I'm somewhat in favor of that option, as it would simplify the above code.

Contributor Author

Okay, I refactored Trainer.rescale_gradients to use training_utils.sparse_clip_norm directly. Otherwise the function is the same.
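
The idea discussed in this thread, choosing which parameters to clip depending on whether Amp is active, might look roughly like this. It is a sketch and not the merged implementation: `parameters_to_clip` is written here as a free function with the `amp` module passed in, whereas the PR keeps the logic inside `Trainer.rescale_gradients`. The only Apex call assumed is `amp.master_params(optimizer)`.

```python
from typing import Any, Iterable, List, Optional


def parameters_to_clip(
    model: Any, optimizer: Any, opt_level: Optional[str], amp: Any
) -> List[Any]:
    """Pick the parameters whose gradients should be clipped."""
    if opt_level is not None:
        # With Amp enabled, iterate over the params owned by the optimizer.
        candidates: Iterable[Any] = amp.master_params(optimizer)
    else:
        candidates = model.parameters()
    # Parameters without a gradient cannot be clipped.
    return [p for p in candidates if p.grad is not None]
```

The returned list is what would then be handed to `training_util.sparse_clip_norm` together with the configured gradient norm.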

Comment on lines +191 to +195
opt_level : `str`, optional, (default = `None`)
Each opt_level establishes a set of properties that govern Amp’s implementation of pure or mixed
precision training. Must be a choice of `"O0"`, `"O1"`, `"O2"`, or `"O3"`.
See the Apex [documentation](https://nvidia.github.io/apex/amp.html#opt-levels-and-properties) for
more details. If `None`, Amp is not used. Defaults to `None`.

Contributor Author

@matt-gardner How's this?

Contributor

@DeNeutoy DeNeutoy left a comment

@JohnGiorgi this PR is great, thanks a lot!

Firstly, you asked how to run the tests. You can do that by:

pip install -r dev-requirements.txt

  • Unit tests: pytest allennlp/tests/path/to/test. The -v and -s flags make it verbose.
  • Auto formatter: black allennlp reformats python code automatically - just commit the changes it makes.
  • Lint: flake8 allennlp
  • Type Checking: bash scripts/mypy.sh (we have some specific mypy settings, so we have a script which runs it automatically.)

It would be good if you could add a test like this, which just checks that the trainer can run with one of the non-None optimisation levels. You can also add a decorator to the test which checks if apex is installed (see discussion below).

Some things we might need to resolve:

  • Apex does not offer a pip installable package, which means either we need to install from github, e.g. pip install git+https://github.com/NVIDIA/apex.git@master in setup.py, or we need to make apex optional. I am in favour of the second option. This would mean moving the apex import inside the check for opt_level being None, and adding the github install to dev-requirements.txt, so that we can at least run a test when it is present (we install dev-requirements when testing, but actually using allennlp must not be dependent on them).

  • Checkpointing the loss scaling seems to require some additional changes. I don't want this to block the addition of apex to allennlp, but we should at least open an issue for it afterward if we can't properly restore a model without it.
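
One possible shape for the checkpointing follow-up is sketched below. It assumes Apex's `amp.state_dict()` and `amp.load_state_dict()` are used to save and restore the loss-scaler state; `build_checkpoint` and `restore_checkpoint` are hypothetical helpers, not AllenNLP functions, and the `amp` module is passed in so the snippet does not import Apex.

```python
from typing import Any, Dict, Optional


def build_checkpoint(
    model: Any, optimizer: Any, opt_level: Optional[str], amp: Any
) -> Dict[str, Any]:
    """Collect model and optimizer state, plus Amp state when Amp is in use."""
    state = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
    if opt_level is not None:
        state["amp"] = amp.state_dict()
    return state


def restore_checkpoint(
    state: Dict[str, Any], model: Any, optimizer: Any, opt_level: Optional[str], amp: Any
) -> None:
    """Restore everything saved by build_checkpoint."""
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    # Older checkpoints may lack the "amp" entry, so check before loading it.
    if opt_level is not None and "amp" in state:
        amp.load_state_dict(state["amp"])
```

In a real trainer the restore step would presumably run after `amp.initialize` has been called with the same opt_level that was used when saving.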

Comment on lines +28 to +30
# Required for automatic mixed precision (AMP) training
git+https://github.com/NVIDIA/apex.git@master

Contributor Author

Okay, as discussed, I have added apex as a dev requirement

Comment on lines 123 to 137

@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device registered.")
@pytest.mark.skipif(importlib.util.find_spec("apex") is None, reason="Apex is not installed.")
def test_trainer_can_run_amp(self):
    from apex import amp
    self.model.cuda()
    trainer = Trainer(
        self.model, self.optimizer, self.data_loader, num_epochs=2, cuda_device=0, opt_level="O1"
    )
    metrics = trainer.train()
    assert "peak_cpu_memory_MB" in metrics
    assert isinstance(metrics["peak_cpu_memory_MB"], float)
    assert metrics["peak_cpu_memory_MB"] > 0
    assert "peak_gpu_0_memory_MB" in metrics
    assert isinstance(metrics["peak_gpu_0_memory_MB"], int)

Contributor Author

@JohnGiorgi JohnGiorgi Feb 28, 2020

Added a test to make sure the model runs with amp. There are two skipif checks, one to make sure a cuda device is available and one to check if apex is importable (I wasn't sure if it made more sense to combine the checks in one decorator or have one check per decorator, went with the latter).

Contributor Author

JohnGiorgi commented Feb 28, 2020

@DeNeutoy Thanks a lot for the guidance. That made it easy to add a test. Also added apex to the dev requirements.

One thing: If I put from apex import amp into the if self._opt_level is not None check, I will get an import error in rescale_gradients when I try to call amp.master_params. Is there a way to import apex within the if statement such that I can reference it outside the constructor's scope?

Good catch on the checkpointing thing. I guess it is up to AllenNLP devs, I could open another PR after this is merged to get checkpointing with an amp-enabled model working OR try to add it to this PR.

Contributor

@DeNeutoy DeNeutoy left a comment

Fine to leave the checkpointing, but open an issue/follow up PR - LGTM with the suggested change for managing the optional dependency above 👍

@@ -7,6 +7,7 @@
import traceback
from typing import Dict, List, Optional, Tuple, Union, Any

from apex import amp

Contributor

I would suggest this:

try:
	from apex import amp
except ImportError:
	amp = None

and then below, check that amp is not None and raise a ConfigurationError if it is and the user has passed an opt level. Does that make sense?
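
Put together, the suggestion could look roughly like the sketch below. The `ConfigurationError` class defined here is a local stand-in for AllenNLP's exception so that the snippet is self-contained, and `check_amp_available` is a hypothetical helper name; in the trainer the check would sit in the constructor.

```python
from typing import Any, Optional

try:
    from apex import amp  # optional: only needed for mixed precision training
except ImportError:
    amp = None


class ConfigurationError(Exception):
    """Local stand-in for AllenNLP's ConfigurationError (assumption for this sketch)."""


def check_amp_available(opt_level: Optional[str], amp_module: Any = amp) -> None:
    """Fail early if an opt_level was requested but Apex could not be imported."""
    if opt_level is not None and amp_module is None:
        raise ConfigurationError(
            "Apex not installed but opt_level was provided. Please install NVIDIA's Apex "
            "to enable automatic mixed precision (AMP) training. "
            "See: https://github.com/NVIDIA/apex."
        )
```

Because the module-level name `amp` is always defined (either the module or `None`), code outside the constructor can still reference it without a NameError.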

Contributor Author

Done.

Contributor

bryant1410 commented Feb 28, 2020

As mentioned in #3851, there are some things in the codebase (e.g., some metrics) that IMHO wouldn't play well with FP16, and I have had problems with them in the past, such as code that adds + 1e-13. I guess those will be fixed later, when people detect and report them, probably after seeing radically different results with and without some specific AMP opt_level.
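
To illustrate why a constant like 1e-13 can be a problem in half precision, here is a small standard-library sketch. It uses the `struct` module's `"e"` format to round a Python float to IEEE 754 half precision; `to_fp16` is an illustrative helper, and whether a given AllenNLP metric actually hits this depends on where the cast to FP16 happens.

```python
import struct


def to_fp16(value: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


epsilon = 1e-13
# 1e-13 is far below the smallest positive FP16 value, so it rounds to zero.
print(to_fp16(epsilon))  # 0.0
# A larger stabiliser such as 1e-4 survives the conversion (with rounding).
print(to_fp16(1e-4))
```

Once the stabiliser has rounded to zero, an expression like x / (y + epsilon) no longer protects against y being zero.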

Comment on lines 283 to 288
if amp is None:
    raise ConfigurationError(
        ("Apex not installed but opt_level was provided. Please install NVIDIA's Apex to enable"
         " automatic mixed precision (AMP) training. See: https://github.com/NVIDIA/apex.")
    )

Contributor Author

@JohnGiorgi JohnGiorgi Feb 28, 2020

As discussed! -- Whoops, caught a bug. It is now if amp is None.

Contributor

@DeNeutoy DeNeutoy left a comment

@JohnGiorgi One last thing and then this is good to go, thanks a lot!

opt_level="O1",
)
metrics = trainer.train()
assert "peak_cpu_memory_MB" in metrics

Contributor

You can delete these asserts, as they aren't related to this test.

Contributor Author

Done!

@@ -121,6 +122,26 @@ def test_passing_trainer_multiple_gpus_raises_error(self):
self.model, self.optimizer, self.data_loader, num_epochs=2, cuda_device=[0, 1],
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device registered.")
@pytest.mark.skipif(importlib.util.find_spec("apex") is None, reason="Apex is not installed.")

Contributor

This importlib thing is doing something funny to our CI, can you just change it to how we imported it in the trainer itself?

try:
	import apex
except ImportError:
	apex = None

...

@pytest.mark.skipif(apex is None, ...)
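
The same skip-if-missing pattern can be shown in a form that runs without pytest. The sketch below swaps in the standard library's `unittest.skipIf` for `pytest.mark.skipif`; `AmpSmokeTest` and its placeholder assertion are illustrative only, since the real test in this PR trains a model with opt_level="O1".

```python
import unittest

try:
    import apex
except ImportError:
    apex = None


class AmpSmokeTest(unittest.TestCase):
    @unittest.skipIf(apex is None, "Apex is not installed.")
    def test_apex_exposes_amp(self):
        # Placeholder check; runs only when the optional dependency is present.
        from apex import amp

        self.assertTrue(hasattr(amp, "initialize"))
```

Evaluating `apex is None` at import time avoids calling `importlib.util.find_spec` inside the decorator, which is the part that was reported to confuse CI.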

Contributor Author

Done!

@ghost ghost merged commit 44ed34c into allenai:master Mar 2, 2020
@bryant1410
Contributor

I saw the GPU tests failed after this PR. Maybe there's something to review?

Can the bulldozer run the GPU tests and report if they fail?

@schmmd schmmd added this to the Performance milestone May 20, 2020
This pull request was closed.
5 participants