This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Large training batches on limited GPU hardware #754

Merged
merged 14 commits into tensorflow:master on Jun 5, 2018

Conversation

@fstahlberg (Contributor) commented Apr 29, 2018

This PR adds a LargebatchAdam optimizer, which accumulates gradients over n batches and applies the Adam learning rule every n batches on the accumulated gradients. This makes it possible to arbitrarily increase the effective batch size / number of GPUs at the cost of more training iterations. The technique is useful when the number of physical GPUs is limited or GPU memory does not allow increasing the batch size any further. Large-batch / multi-GPU training is often important for Transformer training, as reported in #444. See Saunders et al., 2018 for more details.

See transformer_base_fake_gpu8 hparams set as an example.

This is a new version of the PR #750 which fixes issues with the Google CLA.
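
For readers new to the technique, here is a minimal NumPy sketch of the accumulation idea described above. The callables and names are illustrative placeholders, not the optimizer code in this PR, and it assumes the accumulated gradients are averaged before the Adam update.

import numpy as np

def train_with_accumulation(compute_gradients, apply_adam_update, batches, n):
  """Simulate an n-times larger batch by accumulating gradients over n batches.

  `compute_gradients` and `apply_adam_update` are hypothetical callables that
  stand in for the model's gradient computation and the Adam learning rule.
  """
  accumulated = None
  for step, batch in enumerate(batches, start=1):
    grads = compute_gradients(batch)
    if accumulated is None:
      accumulated = [np.zeros_like(g) for g in grads]
    for acc, g in zip(accumulated, grads):
      acc += g
    if step % n == 0:
      # One Adam update on the averaged gradients of the last n batches.
      apply_adam_update([acc / n for acc in accumulated])
      accumulated = None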

@rsepassi (Contributor) left a comment:

Thanks @fstahlberg! Very cool optimizer.

I think something of this size also warrants a test.

@@ -1081,6 +1081,17 @@ def transformer_base_single_gpu():
return hparams


@registry.register_hparams
def transformer_base_fake_gpu8():
"""HParams for simulating 8 GPU transformer base model training
Contributor:

Update docstring:
HParams for simulating 8 GPUs with LargebatchAdam optimizer.

# Dependency imports

import tensorflow as tf
from tensorflow.python.eager import context
Contributor:

Use tf.contrib.eager.in_eager_mode()


import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
Contributor:

Instead of these specific imports, can you switch to accessing through tf?

Contributor Author:

That was a legacy from using the AdamOptimizer as blueprint. Fixed now.



class LargebatchAdamOptimizer(tf.contrib.opt.LazyAdamOptimizer):
"""Adam with delayed SGD updates."""
Contributor:

Adam with SGD updates every n steps with accumulated gradients.

self._n = n # Call Adam optimizer every n batches with accumulated grads
self._n_t = None # n as tensor

def _create_slots(self, var_list):
Contributor:

First call super, then just add the gradient accumulators.
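
A sketch of what the suggested structure could look like (the slot name "grad_acc" is an assumption; only the accumulator slot is added after calling the parent implementation):

def _create_slots(self, var_list):
  """Create the standard Adam slots via the parent class, then add one
  gradient accumulator slot per variable (slot name is illustrative)."""
  super(MultistepAdamOptimizer, self)._create_slots(var_list)
  for var in var_list:
    self._zeros_slot(var, "grad_acc", self._name)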

use_locking=self._use_locking)
return control_flow_ops.group(update_beta1, update_beta2)
maybe_update_beta = tf.cond(tf.equal(iter_, 0),
lambda: update_beta_op(),
Contributor:

update_beta_op

return control_flow_ops.group(update_beta1, update_beta2)
maybe_update_beta = tf.cond(tf.equal(iter_, 0),
lambda: update_beta_op(),
lambda: tf.no_op())
Contributor:

tf.no_op
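
Taken together, these two suggestions amount to passing the callables directly instead of wrapping them in lambdas, roughly:

# Suggested form (identifiers taken from the diff excerpt above):
maybe_update_beta = tf.cond(tf.equal(iter_, 0), update_beta_op, tf.no_op)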

fake_gpu_multiplier = tf.constant(hparams.fake_gpu_multiplier,
dtype=tf.float32)
step = step / fake_gpu_multiplier
tf.logging.info("Scaling down learning rate decay by "
Contributor:

Divided global step by fake_gpu_multiplier=%d

"""
hparams = transformer_base()
hparams.optimizer = "LargebatchAdam"
hparams.add_hparam("fake_gpu_multiplier", 8)
Contributor:

add this to common_hparams.py basic_params1 instead with a default value of None

Contributor Author:

Done. I renamed it to optimizer_multistep_accumulate_steps for consistency with the optimizer_ada[m|factor]_* options


See [Saunders et al., 2018](https://arxiv.org/abs/1805.00456) for details.
"""

Contributor:

I'd like to find a different name for the file and the class.

How about multistep_optimizer and MultistepAdamOptimizer?

@vince62s (Contributor) commented May 6, 2018

@fstahlberg I confirm this is great; we implemented it in opennmt-py and it works fine.
However, out of curiosity, did you try it with an accumulation of 8?
I was able to run with 4 without any problem (for a batch size of 4096), but I was not able to fit 8 on a GTX 1080 Ti.
Thanks.

@fstahlberg (Contributor Author):

@vince62s Yes, I tried it with "delay factor" 8 - there should be no difference regarding GPU memory between 4 and 8. Did you use the same code?

@rsepassi thanks for the review, I'll work on it in the next few days.

@rsepassi (Contributor) commented May 7, 2018 via email

@fstahlberg (Contributor Author) left a comment:

Pushed all the code changes. Tests will follow soon-ish.

It is now called MultistepAdamOptimizer and optimizer_multistep_accumulate_steps.
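
A minimal hparams override along these lines, mirroring the transformer_base_fake_gpu8 set shown earlier (a sketch; the hparams-set name is illustrative and it assumes the renamed optimizer is registered as "MultistepAdam"):

from tensor2tensor.models.transformer import transformer_base
from tensor2tensor.utils import registry


@registry.register_hparams
def transformer_base_multistep8():
  """Simulate 8 GPUs by accumulating gradients over 8 optimizer steps."""
  hparams = transformer_base()
  hparams.optimizer = "MultistepAdam"
  hparams.optimizer_multistep_accumulate_steps = 8
  return hparams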

"""
hparams = transformer_base()
hparams.optimizer = "LargebatchAdam"
hparams.add_hparam("fake_gpu_multiplier", 8)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. I renamed it to optimizer_multistep_accumulate_steps for consistency with the optimizer_ada[m|factor]_* options


import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was a legacy from using the AdamOptimizer as blueprint. Fixed now.

super(LargebatchAdamOptimizer, self)._apply_sparse_shared, grad, var,
indices, scatter_add)

def _apply_sparse(self, grad, var):
Contributor Author:

Are you sure I can do that? The optimizer works even for sparse tensors, just not as efficiently as it could, since I simply convert to dense and use _apply_dense.

@fstahlberg (Contributor Author):

@rsepassi I've added a unit test, but I had to put an awkward version check to pass all tests since it doesn't work with TF < 1.6.

@rsepassi (Contributor) left a comment:

Very sorry for the long delay; the NIPS deadline had us quite busy. Looks good, though I think the test should be modified.

"""Adam with SGD updates every n steps with accumulated gradients."""

def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
use_locking=False, name="Adam", n=2):
Contributor:

let's make the default 1

Contributor Author:

done

"""Adjust global step if a multi-step optimizer is used."""
step = tf.to_float(tf.train.get_or_create_global_step())
multiplier = hparams.optimizer_multistep_accumulate_steps
if multiplier is not None and multiplier > 1:
Contributor:

if multiplier:

Contributor Author:

done

step = tf.to_float(tf.train.get_or_create_global_step())
multiplier = hparams.optimizer_multistep_accumulate_steps
if multiplier is not None and multiplier > 1:
step = step / tf.constant(multiplier, dtype=tf.float32)
Contributor:

tf.to_float(step) / tf.to_float(multiplier)

Contributor Author:

done (step is already a float tensor)
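
Putting the diff excerpt and the suggestions above together, the step adjustment ends up roughly as follows (a sketch, not the exact merged code; the function name is illustrative):

def _simulated_global_step(hparams):
  """Scale the global step so learning-rate schedules count optimizer updates
  rather than accumulation steps."""
  step = tf.to_float(tf.train.get_or_create_global_step())
  multiplier = hparams.optimizer_multistep_accumulate_steps
  if multiplier:  # None or 0 means no accumulation; dividing by 1 is a no-op
    step /= tf.to_float(multiplier)
  return step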

@@ -0,0 +1,123 @@
# coding=utf-8
Contributor:

Thank you for adding this test, but I think the test should be a bit different and hopefully simpler:

Compare 2 things:

  1. AdamOptimizer with batch size 32 for 1 step
  2. MultistepAdamOptimizer with batch size 8 for 4 steps with n=4

We should see that the updates are identical (i.e. the variables end up in the same place)
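
A minimal sketch of such an equivalence check in TF 1.x style (the data shapes, loss, and import path are illustrative assumptions, and it assumes the multistep optimizer effectively averages the accumulated gradients, as the author describes below):

import numpy as np
import tensorflow as tf

from tensor2tensor.utils.multistep_optimizer import MultistepAdamOptimizer


def run_optimizer(optimizer, batches):
  """Fit a tiny least-squares model on the given batches; return the final weights."""
  tf.reset_default_graph()
  w = tf.get_variable("w", initializer=tf.zeros([10, 1]))
  x = tf.placeholder(tf.float32, [None, 10])
  y = tf.placeholder(tf.float32, [None, 1])
  loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))
  train_op = optimizer.minimize(loss)
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for bx, by in batches:
      sess.run(train_op, feed_dict={x: bx, y: by})
    return sess.run(w)


np.random.seed(0)
features = np.random.randn(32, 10).astype(np.float32)
targets = np.random.randn(32, 1).astype(np.float32)

# 1) Plain Adam, one step on the full batch of 32.
w_large = run_optimizer(tf.train.AdamOptimizer(), [(features, targets)])
# 2) MultistepAdam with n=4, four steps on batches of 8.
small_batches = [(features[i:i + 8], targets[i:i + 8]) for i in range(0, 32, 8)]
w_multi = run_optimizer(MultistepAdamOptimizer(n=4), small_batches)

# The variables should end up in (numerically) the same place.
np.testing.assert_allclose(w_large, w_multi, rtol=1e-5)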

Contributor Author:

Hm, am I not already doing something like this? The difference is just that I compare numbers of updates rather than batch sizes. For example, I compare the variables after

  1. Adam: t steps with averaged gradients over n steps
  2. MultistepAdam: t*n steps

for n=1..4 and t=1..3. Adam is implemented in NumPy to avoid introducing dependencies from this test class on other parts of the TF code (as in the original Adam test).

Contributor:

So the original Adam test uses NumPy because it's actually checking the mathematical accuracy of the implementation. Here we want to ensure that the MultistepAdamOptimizer is a drop-in replacement for AdamOptimizer and simulates a larger batch size, so I think the clearest and most useful test would check exactly that. Do you agree?

Contributor Author:

Alright, no problem, I'll change it.

@googlebot added the "cla: yes" (PR author has signed CLA) label on Jun 3, 2018
@fstahlberg (Contributor Author):

Test updated

@rsepassi (Contributor) commented Jun 5, 2018

Looks great! Thanks so much for this contribution @fstahlberg. Really good work.

@rsepassi merged commit 64e1df1 into tensorflow:master on Jun 5, 2018
tensorflow-copybara pushed a commit that referenced this pull request on Jun 5, 2018 (PiperOrigin-RevId: 199354554).
whr94621 pushed a commit to whr94621/tensor2tensor that referenced this pull request on Jun 12, 2018: "…dware (tensorflow#754)" (simulates n times more GPUs at the cost of n times more training iterations).
whr94621 pushed a commit to whr94621/tensor2tensor that referenced this pull request on Jun 12, 2018 (PiperOrigin-RevId: 199354554).
@@ -64,6 +64,8 @@ def basic_params1():
optimizer_adafactor_memory_exponent=0.8,
optimizer_adafactor_clipping_threshold=1.0,
optimizer_adafactor_multiply_by_parameter_scale=True,
# Number of accumulating steps for multi step optimizers.
optimizer_multistep_accumulate_steps=None,
Contributor:

I think the default value should be 1 instead of None. Otherwise a "NoneType takes no arguments" error occurs when parsing the value from the --hparams flag.

@nxphi47 commented Sep 29, 2018

Hello, with these hparams, how many --train_steps should we set in the training script to exactly replicate the 100000 steps on 8 real GPUs from the Transformer paper?

Is it still --train_steps=100000 or --train_steps=800000?

@fstahlberg (Contributor Author):

@nxphi47 It is --train_steps=800000

@XiaoqingNLP (Contributor):

@fstahlberg Thank you so much. I would also like to know: how is the performance?

@fstahlberg (Contributor Author):

@zxqchat For example, if you set optimizer_multistep_accumulate_steps to 8 and multiply train_steps by 8, you get the same performance as with 8 times more GPUs.

Labels: cla: yes (PR author has signed CLA)
7 participants