From c482eb00d9ee4ef0571fa71640113f9abaa7b083 Mon Sep 17 00:00:00 2001 From: PatWie Date: Mon, 3 Aug 2015 17:31:14 +0200 Subject: [PATCH] Adam solver This commit implements the Adam solver by Kingma et. al for CPU and GPU. All solver parameters are defined in the caffe.proto. This also adds an example for the MNIST dataset. --- examples/mnist/lenet_solver_adam.prototxt | 26 ++++ examples/mnist/train_lenet_adam.sh | 3 + include/caffe/solver.hpp | 17 +++ src/caffe/proto/caffe.proto | 7 +- src/caffe/solver.cpp | 105 +++++++++++++++ src/caffe/test/test_gradient_based_solver.cpp | 125 +++++++++++++++++- 6 files changed, 279 insertions(+), 4 deletions(-) create mode 100644 examples/mnist/lenet_solver_adam.prototxt create mode 100755 examples/mnist/train_lenet_adam.sh diff --git a/examples/mnist/lenet_solver_adam.prototxt b/examples/mnist/lenet_solver_adam.prototxt new file mode 100644 index 00000000000..d22c5718f3f --- /dev/null +++ b/examples/mnist/lenet_solver_adam.prototxt @@ -0,0 +1,26 @@ +# The train/test net protocol buffer definition +# this follows "ADAM: A METHOD FOR STOCHASTIC OPTIMIZATION" +net: "examples/mnist/lenet_train_test.prototxt" +# test_iter specifies how many forward passes the test should carry out. +# In the case of MNIST, we have test batch size 100 and 100 test iterations, +# covering the full 10,000 testing images. +test_iter: 100 +# Carry out testing every 500 training iterations. +test_interval: 500 +# All parameters are from the cited paper above +base_lr: 0.001 +momentum: 0.9 +momentum2: 0.999 +# since Adam dynamically changes the learning rate, we set the base learning +# rate to a fixed value +lr_policy: "fixed" +# Display every 100 iterations +display: 100 +# The maximum number of iterations +max_iter: 10000 +# snapshot intermediate results +snapshot: 5000 +snapshot_prefix: "examples/mnist/lenet" +# solver mode: CPU or GPU +solver_type: ADAM +solver_mode: GPU diff --git a/examples/mnist/train_lenet_adam.sh b/examples/mnist/train_lenet_adam.sh new file mode 100755 index 00000000000..a32ecf2d9c2 --- /dev/null +++ b/examples/mnist/train_lenet_adam.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env sh + +./build/tools/caffe train --solver=examples/mnist/lenet_solver_adam.prototxt diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index d2b99923f23..582aa1427d3 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -217,6 +217,21 @@ class AdaDeltaSolver : public SGDSolver { DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver); }; +template +class AdamSolver : public SGDSolver { + public: + explicit AdamSolver(const SolverParameter& param) + : SGDSolver(param) { AdamPreSolve();} + explicit AdamSolver(const string& param_file) + : SGDSolver(param_file) { AdamPreSolve(); } + + protected: + void AdamPreSolve(); + virtual void ComputeUpdateValue(int param_id, Dtype rate); + + DISABLE_COPY_AND_ASSIGN(AdamSolver); +}; + template Solver* GetSolver(const SolverParameter& param) { SolverParameter_SolverType type = param.solver_type(); @@ -232,6 +247,8 @@ Solver* GetSolver(const SolverParameter& param) { return new RMSPropSolver(param); case SolverParameter_SolverType_ADADELTA: return new AdaDeltaSolver(param); + case SolverParameter_SolverType_ADAM: + return new AdamSolver(param); default: LOG(FATAL) << "Unknown SolverType: " << type; } diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index fc0d961abda..d4c97d2bd06 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -98,7 +98,7 @@ message NetParameter { // NOTE // Update the next available ID when you add a new SolverParameter field. // -// SolverParameter next available ID: 39 (last added: rms_decay) +// SolverParameter next available ID: 40 (last added: momentum2) message SolverParameter { ////////////////////////////////////////////////////////////////////////////// // Specifying the train and test networks @@ -216,10 +216,13 @@ message SolverParameter { ADAGRAD = 2; RMSPROP = 3; ADADELTA = 4; + ADAM = 5; } optional SolverType solver_type = 30 [default = SGD]; - // numerical stability for AdaGrad + // numerical stability for RMSProp, AdaGrad and AdaDelta and Adam optional float delta = 31 [default = 1e-8]; + // parameters for the Adam solver + optional float momentum2 = 39 [default = 0.999]; // RMSProp decay value // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t) diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 248f238eb76..ef88d9d93f1 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -1114,11 +1114,116 @@ void AdaDeltaSolver::ComputeUpdateValue(int param_id, Dtype rate) { } } +template +void AdamSolver::AdamPreSolve() { + // Add the extra history entries for Adam after those from + // SGDSolver::PreSolve + const vector*>& net_params = this->net_->learnable_params(); + for (int i = 0; i < net_params.size(); ++i) { + const vector& shape = net_params[i]->shape(); + this->history_.push_back( + shared_ptr >(new Blob(shape))); + } +} + +template +void AdamSolver::ComputeUpdateValue(int param_id, Dtype rate) { + const vector*>& net_params = this->net_->learnable_params(); + const vector& net_params_lr = this->net_->params_lr(); + Dtype local_rate = rate * net_params_lr[param_id]; + const Dtype beta1 = this->param_.momentum(); + const Dtype beta2 = this->param_.momentum2(); + + // we create aliases for convenience + size_t update_history_offset = net_params.size(); + shared_ptr > val_m = this->history_[param_id]; + shared_ptr > val_v = + this->history_[param_id + update_history_offset]; + shared_ptr > val_t = this->temp_[param_id]; + + const int t = this->iter_ + 1; + const Dtype correction = std::sqrt(Dtype(1)-pow(beta2, t))/ + (Dtype(1.)-pow(beta1, t)); + const int N = net_params[param_id]->count(); + const Dtype eps_hat = this->param_.delta(); + + switch (Caffe::mode()) { + case Caffe::CPU: { + // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t + caffe_cpu_axpby(N, Dtype(1)-beta1, + net_params[param_id]->cpu_diff(), beta1, + val_m->mutable_cpu_data()); + + // update v <- \beta_2 m_{t-1} + (1-\beta_2)g_t^2 + caffe_mul(N, + net_params[param_id]->cpu_diff(), + net_params[param_id]->cpu_diff(), + val_t->mutable_cpu_data()); + caffe_cpu_axpby(N, Dtype(1)-beta2, + val_t->cpu_data(), beta2, + val_v->mutable_cpu_data()); + + // set update + caffe_powx(N, + val_v->cpu_data(), Dtype(0.5), + val_t->mutable_cpu_data()); + caffe_add_scalar(N, eps_hat, val_t->mutable_cpu_data()); + caffe_div(N, + val_m->cpu_data(), + val_t->cpu_data(), + val_t->mutable_cpu_data()); + + caffe_cpu_scale(N, local_rate*correction, + val_t->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); + break; + } + case Caffe::GPU: { +#ifndef CPU_ONLY + // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t + caffe_gpu_axpby(N, Dtype(1)-beta1, + net_params[param_id]->gpu_diff(), beta1, + val_m->mutable_gpu_data()); + + // update v <- \beta_2 m_{t-1} + (1-\beta_2)g_t^2 + caffe_gpu_mul(N, + net_params[param_id]->gpu_diff(), + net_params[param_id]->gpu_diff(), + val_t->mutable_gpu_data()); + caffe_gpu_axpby(N, Dtype(1)-beta2, + val_t->gpu_data(), beta2, + val_v->mutable_gpu_data()); + + // set update + caffe_gpu_powx(N, + val_v->gpu_data(), Dtype(0.5), + val_t->mutable_gpu_data()); + caffe_gpu_add_scalar(N, eps_hat, + val_t->mutable_gpu_data()); + caffe_gpu_div(N, + val_m->gpu_data(), + val_t->gpu_data(), + val_t->mutable_gpu_data()); + + caffe_gpu_scale(N, local_rate*correction, + val_t->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); +#else + NO_GPU; +#endif + break; + } + default: + LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + } +} + INSTANTIATE_CLASS(Solver); INSTANTIATE_CLASS(SGDSolver); INSTANTIATE_CLASS(NesterovSolver); INSTANTIATE_CLASS(AdaGradSolver); INSTANTIATE_CLASS(RMSPropSolver); INSTANTIATE_CLASS(AdaDeltaSolver); +INSTANTIATE_CLASS(AdamSolver); } // namespace caffe diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp index 1d255a86621..fac4a9aa679 100644 --- a/src/caffe/test/test_gradient_based_solver.cpp +++ b/src/caffe/test/test_gradient_based_solver.cpp @@ -67,7 +67,8 @@ class GradientBasedSolverTest : public MultiDeviceTest { InitSolver(param); delta_ = (solver_type() == SolverParameter_SolverType_ADAGRAD || solver_type() == SolverParameter_SolverType_RMSPROP || - solver_type() == SolverParameter_SolverType_ADADELTA) ? + solver_type() == SolverParameter_SolverType_ADADELTA || + solver_type() == SolverParameter_SolverType_ADAM) ? param.delta() : 0; } @@ -253,6 +254,7 @@ class GradientBasedSolverTest : public MultiDeviceTest { updated_bias.ReshapeLike(bias); for (int i = 0; i <= D; ++i) { + // Compute the derivative with respect to the ith weight (i.e., the ith // element of the gradient). Dtype grad = 0; @@ -275,6 +277,7 @@ class GradientBasedSolverTest : public MultiDeviceTest { const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i]; grad -= element_i * targets.cpu_data()[k]; } + // Scale the gradient over the N samples. grad /= N; // Add the weight decay to the gradient. @@ -282,7 +285,8 @@ class GradientBasedSolverTest : public MultiDeviceTest { ((i == D) ? bias.cpu_data()[0] : weights.cpu_data()[i]); // Finally, compute update. const vector > >& history = solver_->history(); - if (solver_type() != SolverParameter_SolverType_ADADELTA) { + if (solver_type() != SolverParameter_SolverType_ADADELTA + && solver_type() != SolverParameter_SolverType_ADAM) { ASSERT_EQ(2, history.size()); // 1 blob for weights, 1 for bias } else { ASSERT_EQ(4, history.size()); // additional blobs for update history @@ -322,6 +326,18 @@ class GradientBasedSolverTest : public MultiDeviceTest { // momentum * update_history_value + (1 - momentum) * (update_value); break; } + case SolverParameter_SolverType_ADAM: { + const Dtype momentum2 = 0.999; + const Dtype m = (i == D) ? + history[1]->cpu_data()[0] : history[0]->cpu_data()[i]; + const Dtype v = (i == D) ? + history[3]->cpu_data()[0] : history[2]->cpu_data()[i]; + const Dtype val_m = (1-momentum)*grad + momentum*m; + const Dtype val_v = (1-momentum2)*grad*grad + momentum2*v; + Dtype alpha_t = std::sqrt(1-momentum2)/(1-momentum); + update_value = alpha_t * val_m / (std::sqrt(val_v) + delta_); + break; + } default: LOG(FATAL) << "Unknown solver type: " << solver_type(); } @@ -1061,6 +1077,111 @@ TYPED_TEST(AdaDeltaSolverTest, TestSnapshotShare) { } } +template +class AdamSolverTest : public GradientBasedSolverTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + virtual void InitSolver(const SolverParameter& param) { + SolverParameter new_param = param; + const Dtype momentum = 0.9; + new_param.set_momentum(momentum); + const Dtype momentum2 = 0.999; + new_param.set_momentum2(momentum2); + this->solver_.reset(new AdamSolver(new_param)); + } + virtual SolverParameter_SolverType solver_type() { + return SolverParameter_SolverType_ADAM; + } +}; + +TYPED_TEST_CASE(AdamSolverTest, TestDtypesAndDevices); + +TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdate) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 0.0; + const Dtype kMomentum = 0.9; + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum); +} + +TYPED_TEST(AdamSolverTest, TestAdaDeltaLeastSquaresUpdateWithWeightDecay) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 0.5; + const Dtype kMomentum = 0.9; + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum); +} + +TYPED_TEST(AdamSolverTest, TestAdaDeltaLeastSquaresUpdateWithEverything) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 0.5; + const Dtype kMomentum = 0.9; + const int kNumIters = 4; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +TYPED_TEST(AdamSolverTest, TestAdaDeltaLeastSquaresUpdateWithEverythingShare) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 0.5; + const Dtype kMomentum = 0.9; + const int kNumIters = 4; + this->share_ = true; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +TYPED_TEST(AdamSolverTest, TestLeastSquaresUpdateWithEverythingAccum) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 0.5; + const Dtype kMomentum = 0.9; + const int kNumIters = 4; + const int kIterSize = 2; + this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters, + kIterSize); +} + +TYPED_TEST(AdamSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 0.5; + const Dtype kMomentum = 0.9; + const int kNumIters = 4; + const int kIterSize = 2; + this->share_ = true; + this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters, + kIterSize); +} + +TYPED_TEST(AdamSolverTest, TestSnapshot) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 0.5; + const Dtype kMomentum = 0.9; + const int kNumIters = 4; + for (int i = 1; i <= kNumIters; ++i) { + this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +TYPED_TEST(AdamSolverTest, TestSnapshotShare) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 0.5; + const Dtype kMomentum = 0.9; + const int kNumIters = 4; + this->share_ = true; + for (int i = 1; i <= kNumIters; ++i) { + this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i); + } +} + template class RMSPropSolverTest : public GradientBasedSolverTest { typedef typename TypeParam::Dtype Dtype;