From b7ab5b7e96cb9ef4a054b87175c0004b717e9f09 Mon Sep 17 00:00:00 2001 From: fstahlberg Date: Tue, 5 Jun 2018 12:56:39 -0700 Subject: [PATCH] Add MultistepAdamOptimizer: Large training batches on limited GPU hardware (#754) Simulates n times more GPUs at cost of n times more training iterations --- tensor2tensor/layers/common_hparams.py | 2 + tensor2tensor/models/transformer.py | 9 ++ tensor2tensor/utils/learning_rate.py | 17 ++- tensor2tensor/utils/multistep_optimizer.py | 139 ++++++++++++++++++ .../utils/multistep_optimizer_test.py | 106 +++++++++++++ tensor2tensor/utils/optimize.py | 8 + 6 files changed, 278 insertions(+), 3 deletions(-) create mode 100644 tensor2tensor/utils/multistep_optimizer.py create mode 100644 tensor2tensor/utils/multistep_optimizer_test.py diff --git a/tensor2tensor/layers/common_hparams.py b/tensor2tensor/layers/common_hparams.py index d3ae28d7d..986a220e3 100644 --- a/tensor2tensor/layers/common_hparams.py +++ b/tensor2tensor/layers/common_hparams.py @@ -64,6 +64,8 @@ def basic_params1(): optimizer_adafactor_memory_exponent=0.8, optimizer_adafactor_clipping_threshold=1.0, optimizer_adafactor_multiply_by_parameter_scale=True, + # Number of accumulating steps for multi step optimizers. + optimizer_multistep_accumulate_steps=None, weight_decay=1e-6, weight_noise=0.0, # Defines the learning rate as a product of named functions. diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index f21f2936f..86ed71eb7 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -1136,6 +1136,15 @@ def transformer_base_single_gpu(): return hparams +@registry.register_hparams +def transformer_base_multistep8(): + """HParams for simulating 8 GPUs with MultistepAdam optimizer.""" + hparams = transformer_base() + hparams.optimizer = "MultistepAdam" + hparams.optimizer_multistep_accumulate_steps = 8 + return hparams + + @registry.register_hparams def transformer_parsing_base(): """HParams for parsing on WSJ only.""" diff --git a/tensor2tensor/utils/learning_rate.py b/tensor2tensor/utils/learning_rate.py index 843e484ab..bc4894e20 100644 --- a/tensor2tensor/utils/learning_rate.py +++ b/tensor2tensor/utils/learning_rate.py @@ -40,7 +40,7 @@ def learning_rate_factor(name, step_num, hparams): def learning_rate_schedule(hparams): """Learning rate schedule based on hparams.""" - step_num = tf.to_float(tf.train.get_or_create_global_step()) + step_num = _global_step(hparams) schedule_string = hparams.learning_rate_schedule names = schedule_string.split("*") names = [name.strip() for name in names if name.strip()] @@ -52,7 +52,7 @@ def learning_rate_schedule(hparams): def legacy_learning_rate_schedule(hparams): """Backwards-compatible learning-rate schedule.""" - step_num = tf.to_float(tf.train.get_or_create_global_step()) + step_num = _global_step(hparams) warmup_steps = tf.to_float(hparams.learning_rate_warmup_steps) if hparams.learning_rate_decay_scheme == "noam": ret = 5000.0 * hparams.hidden_size**-0.5 * tf.minimum( @@ -67,6 +67,17 @@ def legacy_learning_rate_schedule(hparams): return ret * optimizer_correction * hparams.learning_rate +def _global_step(hparams): + """Adjust global step if a multi-step optimizer is used.""" + step = tf.to_float(tf.train.get_or_create_global_step()) + multiplier = hparams.optimizer_multistep_accumulate_steps + if multiplier: + step = step / tf.to_float(multiplier) + tf.logging.info("Divided global step by %d for multi-step optimizer." 
+ % multiplier) + return step + + def _legacy_sqrt_decay(step): """Decay like 1 / sqrt(step), multiplied by 500 to normalize.""" return 500.0 / tf.sqrt(tf.maximum(step, 1.0)) @@ -95,7 +106,7 @@ def _learning_rate_decay(hparams, warmup_steps=0): """Learning rate decay multiplier.""" scheme = hparams.learning_rate_decay_scheme warmup_steps = tf.to_float(warmup_steps) - global_step = tf.to_float(tf.train.get_or_create_global_step()) + global_step = _global_step(hparams) if not scheme or scheme == "none": return tf.constant(1.) diff --git a/tensor2tensor/utils/multistep_optimizer.py b/tensor2tensor/utils/multistep_optimizer.py new file mode 100644 index 000000000..9f25349eb --- /dev/null +++ b/tensor2tensor/utils/multistep_optimizer.py @@ -0,0 +1,139 @@ +# coding=utf-8 +# Copyright 2018 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Optimizer variants which make it possible to use very large batch sizes with +limited GPU memory. Optimizers in this module accumulate the gradients for n +batches, and call the optimizer's update rule every n batches with the +accumulated gradients. + +See [Saunders et al., 2018](https://arxiv.org/abs/1805.00456) for details. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +import tensorflow as tf + + +class MultistepAdamOptimizer(tf.train.AdamOptimizer): + """Adam with SGD updates every n steps with accumulated gradients.""" + + def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, + use_locking=False, name="Adam", n=1): + super(MultistepAdamOptimizer, self).__init__( + learning_rate=learning_rate, beta1=beta1, beta2=beta2, epsilon=epsilon, + use_locking=use_locking, name=name) + self._n = n # Call Adam optimizer every n batches with accumulated grads + self._n_t = None # n as tensor + + def _create_slots(self, var_list): + """Create slot variables for Adam with accumulated gradients. + + Like super class method, but additionally creates slots for the gradient + accumulator `acc_grad` and the counter variable. + """ + super(MultistepAdamOptimizer, self)._create_slots(var_list) + first_var = min(var_list, key=lambda x: x.name) + self._create_non_slot_variable(initial_value=0 if self._n == 1 else 1, + name="iter", + colocate_with=first_var) + for v in var_list: + self._zeros_slot(v, "grad_acc", self._name) + + def _get_iter_variable(self): + if tf.contrib.eager.in_eager_mode(): + graph = None + else: + graph = tf.get_default_graph() + return self._get_non_slot_variable("iter", graph=graph) + + def _prepare(self): + super(MultistepAdamOptimizer, self)._prepare() + self._n_t = tf.convert_to_tensor(self._n, name="n") + + def _apply_cond(self, apply_fn, grad, var, *args, **kwargs): + """Conditionally apply or accumulate gradient. + + Call `apply_fn only if the current counter value (iter) is zero. This + method couples common functionality for all _apply_*() implementations + in Adam. 
+ """ + grad_acc = self.get_slot(var, "grad_acc") + + def apply_adam(grad_acc, apply_fn, grad, var, *args, **kwargs): + total_grad = (grad_acc + grad) / tf.cast(self._n_t, grad.dtype) + adam_op = apply_fn(total_grad, var, *args, **kwargs) + with tf.control_dependencies([adam_op]): + grad_acc_to_zero_op = grad_acc.assign(tf.zeros_like(grad_acc), + use_locking=self._use_locking) + return tf.group(adam_op, grad_acc_to_zero_op) + + def accumulate_gradient(grad_acc, grad): + assign_op = tf.assign_add(grad_acc, grad, use_locking=self._use_locking) + return tf.group(assign_op) # Strip return value + + return tf.cond(tf.equal(self._get_iter_variable(), 0), + lambda: apply_adam( + grad_acc, apply_fn, grad, var, *args, **kwargs), + lambda: accumulate_gradient(grad_acc, grad)) + + def _apply_dense(self, grad, var): + return self._apply_cond( + super(MultistepAdamOptimizer, self)._apply_dense, grad, var) + + def _resource_apply_dense(self, grad, var): + return self._apply_cond( + super(MultistepAdamOptimizer, self)._resource_apply_dense, grad, var) + + def _apply_sparse_shared(self, grad, var, indices, scatter_add): + return self._apply_cond( + super(MultistepAdamOptimizer, self)._apply_sparse_shared, grad, var, + indices, scatter_add) + + def _apply_sparse(self, grad, var): + # TODO: Implement a sparse version + dense_grad = tf.convert_to_tensor(grad) + return self._apply_cond( + super(MultistepAdamOptimizer, self)._apply_dense, dense_grad, var) + + def _finish(self, update_ops, name_scope): + """Like super class method, but updates beta_power variables only every + n batches. The iter variable is updated with + + iter <- iter + 1 mod n + """ + iter_ = self._get_iter_variable() + beta1_power, beta2_power = self._get_beta_accumulators() + with tf.control_dependencies(update_ops): + with tf.colocate_with(iter_): + + def update_beta_op(): + update_beta1 = beta1_power.assign( + beta1_power * self._beta1_t, + use_locking=self._use_locking) + update_beta2 = beta2_power.assign( + beta2_power * self._beta2_t, + use_locking=self._use_locking) + return tf.group(update_beta1, update_beta2) + maybe_update_beta = tf.cond( + tf.equal(iter_, 0), update_beta_op, tf.no_op) + with tf.control_dependencies([maybe_update_beta]): + update_iter = iter_.assign(tf.mod(iter_ + 1, self._n_t), + use_locking=self._use_locking) + return tf.group( + *update_ops + [update_iter, maybe_update_beta], name=name_scope) + diff --git a/tensor2tensor/utils/multistep_optimizer_test.py b/tensor2tensor/utils/multistep_optimizer_test.py new file mode 100644 index 000000000..0cfc60482 --- /dev/null +++ b/tensor2tensor/utils/multistep_optimizer_test.py @@ -0,0 +1,106 @@ +# coding=utf-8 +# Copyright 2018 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Multi-step Optimizer Test Module for TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +import numpy as np +import tensorflow as tf +from tensor2tensor.utils.multistep_optimizer import MultistepAdamOptimizer + + +class MultistepAdamOptimizerTest(tf.test.TestCase): + + def testMultistep(self): + ver = tf.__version__.split('.') + # TODO: Remove version check once 1.5 is not tested anymore + if int(ver[0]) <= 1 and int(ver[1]) < 6: + # MultistepAdamOptimizer requires TF >= 1.6 + return + dtype = tf.float32 + beta1 = 0.2 + beta2 = 0.99 + alpha = 10.0 + grads0_np_lst = [ + np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype), + np.array([0.2, -0.1], dtype=dtype.as_numpy_dtype), + np.array([0.3, 0.1], dtype=dtype.as_numpy_dtype), + np.array([0.4, -0.1], dtype=dtype.as_numpy_dtype) + ] + grads1_np_lst = [ + np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype), + np.array([0.02, 0.02], dtype=dtype.as_numpy_dtype), + np.array([-0.04, 0.04], dtype=dtype.as_numpy_dtype), + np.array([-0.04, 0.06], dtype=dtype.as_numpy_dtype) + ] + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + # Test accumulating gradients for n=1..4 steps + for n in range(1, 5): + with self.test_session(): + with self.test_session(graph=tf.Graph()): + singlestep_var0 = tf.Variable(var0_np) + singlestep_var1 = tf.Variable(var1_np) + + multistep_var0 = tf.Variable(var0_np) + multistep_var1 = tf.Variable(var1_np) + + singlestep_opt = tf.train.AdamOptimizer( + beta1=beta1, beta2=beta2, learning_rate=alpha) + multistep_opt = MultistepAdamOptimizer( + n=n, beta1=beta1, beta2=beta2, learning_rate=alpha) + + singlestep_update = singlestep_opt.apply_gradients([ + (tf.constant(sum(grads0_np_lst[:n]) / n), singlestep_var0), + (tf.constant(sum(grads1_np_lst[:n]) / n), singlestep_var1)]) + multistep_updates = [ + multistep_opt.apply_gradients([(tf.constant(g0), multistep_var0), + (tf.constant(g1), multistep_var1)]) + for g0, g1 in zip(grads0_np_lst, grads1_np_lst)][:n] + + self.evaluate(tf.global_variables_initializer()) + (singlestep_beta1_power, + singlestep_beta2_power) = singlestep_opt._get_beta_accumulators() + (multistep_beta1_power, + multistep_beta2_power) = multistep_opt._get_beta_accumulators() + + # Run 3 steps of Adam + for _ in range(1, 4): + self.evaluate(singlestep_update) + for multistep_update in multistep_updates: + self.evaluate(multistep_update) + + self.assertAllCloseAccordingToType( + self.evaluate(singlestep_beta1_power), + self.evaluate(multistep_beta1_power)) + self.assertAllCloseAccordingToType( + self.evaluate(singlestep_beta2_power), + self.evaluate(multistep_beta2_power)) + # Validate updated params + self.assertAllCloseAccordingToType( + self.evaluate(singlestep_var0), + self.evaluate(multistep_var0)) + self.assertAllCloseAccordingToType( + self.evaluate(singlestep_var1), + self.evaluate(multistep_var1)) + + +if __name__ == "__main__": + tf.test.main() + diff --git a/tensor2tensor/utils/optimize.py b/tensor2tensor/utils/optimize.py index b973f9ed3..a64a2869a 100644 --- a/tensor2tensor/utils/optimize.py +++ b/tensor2tensor/utils/optimize.py @@ -20,6 +20,7 @@ from tensor2tensor.layers import common_layers from tensor2tensor.utils import adafactor +from tensor2tensor.utils import multistep_optimizer from tensor2tensor.utils import yellowfin import tensorflow as tf @@ -84,6 +85,13 @@ def __init__(self, optimizer_name, lr, hparams, use_tpu=False): 
# pylint: disab beta1=hparams.optimizer_adam_beta1, beta2=hparams.optimizer_adam_beta2, epsilon=hparams.optimizer_adam_epsilon) + elif optimizer_name == "MultistepAdam": + self._opt = multistep_optimizer.MultistepAdamOptimizer( + lr, + beta1=hparams.optimizer_adam_beta1, + beta2=hparams.optimizer_adam_beta2, + epsilon=hparams.optimizer_adam_epsilon, + n=hparams.optimizer_multistep_accumulate_steps) elif optimizer_name == "Momentum": self._opt = tf.train.MomentumOptimizer( lr,
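
Usage note: with this patch applied, gradient accumulation is enabled purely through hparams, either via the registered `transformer_base_multistep8` hparams set or, more generally, by setting `optimizer=MultistepAdam` together with `optimizer_multistep_accumulate_steps=n`. The learning-rate schedules in learning_rate.py divide the global step by `n`, so warmup and decay behave as if every group of `n` accumulated batches were a single large-batch step.

The following sketch is not part of the patch; it mirrors the equivalence checked in multistep_optimizer_test.py, namely that `n` MultistepAdam updates on small-batch gradients should match one plain Adam update on the mean of those gradients. Variable names such as `var_multi` are illustrative, and the sketch assumes tensor2tensor with this patch installed plus TensorFlow >= 1.6 (the minimum version noted in the test above).

import numpy as np
import tensorflow as tf

from tensor2tensor.utils.multistep_optimizer import MultistepAdamOptimizer

n = 4
# Four small-batch gradients for a single 2-element weight vector.
grads = [np.array([0.1, -0.2], dtype=np.float32) * (i + 1) for i in range(n)]

with tf.Graph().as_default():
  var_multi = tf.Variable([1.0, 2.0])   # updated by MultistepAdam with n=4
  var_single = tf.Variable([1.0, 2.0])  # updated by plain Adam, one big batch

  multi_opt = MultistepAdamOptimizer(learning_rate=0.1, n=n)
  single_opt = tf.train.AdamOptimizer(learning_rate=0.1)

  # One MultistepAdam update per small batch; only every n-th update modifies
  # the weights, the others merely accumulate into the grad_acc slot.
  multi_updates = [
      multi_opt.apply_gradients([(tf.constant(g), var_multi)]) for g in grads]
  # One plain Adam update on the gradient averaged over the n small batches.
  single_update = single_opt.apply_gradients(
      [(tf.constant(sum(grads) / n), var_single)])

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for update in multi_updates:
      sess.run(update)
    sess.run(single_update)
    # Both variables should now hold numerically close values.
    print(sess.run(var_multi), sess.run(var_single))

On the t2t-trainer command line the same effect should be obtainable with something like `--hparams_set=transformer_base_multistep8`, which layers `optimizer=MultistepAdam` and `optimizer_multistep_accumulate_steps=8` on top of `transformer_base`, trading 8x more training iterations for an 8x larger effective batch per weight update on a single GPU.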