diff --git a/examples/mnist/lenet_solver_adam.prototxt b/examples/mnist/lenet_solver_adam.prototxt
new file mode 100644
index 00000000000..d22c5718f3f
--- /dev/null
+++ b/examples/mnist/lenet_solver_adam.prototxt
@@ -0,0 +1,26 @@
+# The train/test net protocol buffer definition
+# this follows "ADAM: A METHOD FOR STOCHASTIC OPTIMIZATION"
+net: "examples/mnist/lenet_train_test.prototxt"
+# test_iter specifies how many forward passes the test should carry out.
+# In the case of MNIST, we have test batch size 100 and 100 test iterations,
+# covering the full 10,000 testing images.
+test_iter: 100
+# Carry out testing every 500 training iterations.
+test_interval: 500
+# All parameters are from the cited paper above
+base_lr: 0.001
+momentum: 0.9
+momentum2: 0.999
+# Since Adam dynamically adjusts the per-parameter step size, we set the base
+# learning rate to a fixed value
+lr_policy: "fixed"
+# Display every 100 iterations
+display: 100
+# The maximum number of iterations
+max_iter: 10000
+# snapshot intermediate results
+snapshot: 5000
+snapshot_prefix: "examples/mnist/lenet"
+solver_type: ADAM
+# solver mode: CPU or GPU
+solver_mode: GPU
diff --git a/examples/mnist/train_lenet_adam.sh b/examples/mnist/train_lenet_adam.sh
new file mode 100755
index 00000000000..a32ecf2d9c2
--- /dev/null
+++ b/examples/mnist/train_lenet_adam.sh
@@ -0,0 +1,3 @@
+#!/usr/bin/env sh

+./build/tools/caffe train --solver=examples/mnist/lenet_solver_adam.prototxt
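For reference, the hyper-parameters above map directly onto the update rule of the cited paper: base_lr is the step size \alpha, momentum is \beta_1, momentum2 is \beta_2, and delta is the stability term \hat{\epsilon}. One Adam step, in LaTeX form, is

    m_t      = \beta_1 m_{t-1} + (1 - \beta_1) g_t
    v_t      = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
    \theta_t = \theta_{t-1} - \alpha \, \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}

Because the \sqrt{1 - \beta_2^t}/(1 - \beta_1^t) factor already rescales every step, the example keeps lr_policy: "fixed" rather than a decay schedule.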
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index d2b99923f23..582aa1427d3 100644
--- a/include/caffe/solver.hpp
+++ b/include/caffe/solver.hpp
@@ -217,6 +217,21 @@ class AdaDeltaSolver : public SGDSolver<Dtype> {
   DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver);
 };
 
+template <typename Dtype>
+class AdamSolver : public SGDSolver<Dtype> {
+ public:
+  explicit AdamSolver(const SolverParameter& param)
+      : SGDSolver<Dtype>(param) { AdamPreSolve(); }
+  explicit AdamSolver(const string& param_file)
+      : SGDSolver<Dtype>(param_file) { AdamPreSolve(); }
+
+ protected:
+  void AdamPreSolve();
+  virtual void ComputeUpdateValue(int param_id, Dtype rate);
+
+  DISABLE_COPY_AND_ASSIGN(AdamSolver);
+};
+
 template <typename Dtype>
 Solver<Dtype>* GetSolver(const SolverParameter& param) {
   SolverParameter_SolverType type = param.solver_type();
@@ -232,6 +247,8 @@ Solver<Dtype>* GetSolver(const SolverParameter& param) {
       return new RMSPropSolver<Dtype>(param);
     case SolverParameter_SolverType_ADADELTA:
       return new AdaDeltaSolver<Dtype>(param);
+    case SolverParameter_SolverType_ADAM:
+      return new AdamSolver<Dtype>(param);
     default:
       LOG(FATAL) << "Unknown SolverType: " << type;
   }
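The factory case above is the only registration step; solver selection stays data-driven through the prototxt. A minimal sketch of the end-to-end flow, roughly what tools/caffe.cpp already does for the train action (illustrative only, error handling omitted):

    #include "caffe/caffe.hpp"

    int main() {
      caffe::SolverParameter param;
      // solver_type: ADAM in the prototxt routes GetSolver to AdamSolver.
      caffe::ReadProtoFromTextFileOrDie(
          "examples/mnist/lenet_solver_adam.prototxt", &param);
      boost::shared_ptr<caffe::Solver<float> > solver(
          caffe::GetSolver<float>(param));
      solver->Solve();  // run the full optimization loop
      return 0;
    }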
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index fc0d961abda..d4c97d2bd06 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -98,7 +98,7 @@ message NetParameter {
 // NOTE
 // Update the next available ID when you add a new SolverParameter field.
 //
-// SolverParameter next available ID: 39 (last added: rms_decay)
+// SolverParameter next available ID: 40 (last added: momentum2)
 message SolverParameter {
   //////////////////////////////////////////////////////////////////////////////
   // Specifying the train and test networks
@@ -216,10 +216,13 @@ message SolverParameter {
     ADAGRAD = 2;
     RMSPROP = 3;
     ADADELTA = 4;
+    ADAM = 5;
   }
   optional SolverType solver_type = 30 [default = SGD];
-  // numerical stability for AdaGrad
+  // numerical stability for RMSProp, AdaGrad, AdaDelta, and Adam
   optional float delta = 31 [default = 1e-8];
+  // parameters for the Adam solver
+  optional float momentum2 = 39 [default = 0.999];
 
   // RMSProp decay value
   // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)
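Since momentum2 rides on the generated protobuf API, it can also be set programmatically; a minimal sketch using the setters protoc derives from the message definition above (values are the paper's defaults):

    #include "caffe/proto/caffe.pb.h"

    int main() {
      caffe::SolverParameter param;
      param.set_solver_type(caffe::SolverParameter_SolverType_ADAM);
      param.set_momentum(0.9f);     // beta_1
      param.set_momentum2(0.999f);  // beta_2, the new field (ID 39)
      param.set_delta(1e-8f);       // eps_hat, the shared stability term
      return 0;
    }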
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 248f238eb76..ef88d9d93f1 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -1114,11 +1114,116 @@ void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
   }
 }
 
+template <typename Dtype>
+void AdamSolver<Dtype>::AdamPreSolve() {
+  // Add the extra history entries for Adam after those from
+  // SGDSolver::PreSolve
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+  for (int i = 0; i < net_params.size(); ++i) {
+    const vector<int>& shape = net_params[i]->shape();
+    this->history_.push_back(
+        shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
+  }
+}
+
+template <typename Dtype>
+void AdamSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+  const vector<float>& net_params_lr = this->net_->params_lr();
+  Dtype local_rate = rate * net_params_lr[param_id];
+  const Dtype beta1 = this->param_.momentum();
+  const Dtype beta2 = this->param_.momentum2();
+
+  // we create aliases for convenience
+  size_t update_history_offset = net_params.size();
+  shared_ptr<Blob<Dtype> > val_m = this->history_[param_id];
+  shared_ptr<Blob<Dtype> > val_v =
+      this->history_[param_id + update_history_offset];
+  shared_ptr<Blob<Dtype> > val_t = this->temp_[param_id];
+
+  const int t = this->iter_ + 1;
+  const Dtype correction = std::sqrt(Dtype(1) - pow(beta2, t)) /
+      (Dtype(1.) - pow(beta1, t));
+  const int N = net_params[param_id]->count();
+  const Dtype eps_hat = this->param_.delta();
+
+  switch (Caffe::mode()) {
+  case Caffe::CPU: {
+    // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t
+    caffe_cpu_axpby(N, Dtype(1)-beta1,
+        net_params[param_id]->cpu_diff(), beta1,
+        val_m->mutable_cpu_data());
+
+    // update v <- \beta_2 v_{t-1} + (1-\beta_2)g_t^2
+    caffe_mul(N,
+        net_params[param_id]->cpu_diff(),
+        net_params[param_id]->cpu_diff(),
+        val_t->mutable_cpu_data());
+    caffe_cpu_axpby(N, Dtype(1)-beta2,
+        val_t->cpu_data(), beta2,
+        val_v->mutable_cpu_data());
+
+    // set update
+    caffe_powx(N,
+        val_v->cpu_data(), Dtype(0.5),
+        val_t->mutable_cpu_data());
+    caffe_add_scalar(N, eps_hat, val_t->mutable_cpu_data());
+    caffe_div(N,
+        val_m->cpu_data(),
+        val_t->cpu_data(),
+        val_t->mutable_cpu_data());
+
+    caffe_cpu_scale(N, local_rate*correction,
+        val_t->cpu_data(),
+        net_params[param_id]->mutable_cpu_diff());
+    break;
+  }
+  case Caffe::GPU: {
+#ifndef CPU_ONLY
+    // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t
+    caffe_gpu_axpby(N, Dtype(1)-beta1,
+        net_params[param_id]->gpu_diff(), beta1,
+        val_m->mutable_gpu_data());
+
+    // update v <- \beta_2 v_{t-1} + (1-\beta_2)g_t^2
+    caffe_gpu_mul(N,
+        net_params[param_id]->gpu_diff(),
+        net_params[param_id]->gpu_diff(),
+        val_t->mutable_gpu_data());
+    caffe_gpu_axpby(N, Dtype(1)-beta2,
+        val_t->gpu_data(), beta2,
+        val_v->mutable_gpu_data());
+
+    // set update
+    caffe_gpu_powx(N,
+        val_v->gpu_data(), Dtype(0.5),
+        val_t->mutable_gpu_data());
+    caffe_gpu_add_scalar(N, eps_hat,
+        val_t->mutable_gpu_data());
+    caffe_gpu_div(N,
+        val_m->gpu_data(),
+        val_t->gpu_data(),
+        val_t->mutable_gpu_data());
+
+    caffe_gpu_scale(N, local_rate*correction,
+        val_t->gpu_data(),
+        net_params[param_id]->mutable_gpu_diff());
+#else
+    NO_GPU;
+#endif
+    break;
+  }
+  default:
+    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+  }
+}
+
 INSTANTIATE_CLASS(Solver);
 INSTANTIATE_CLASS(SGDSolver);
 INSTANTIATE_CLASS(NesterovSolver);
 INSTANTIATE_CLASS(AdaGradSolver);
 INSTANTIATE_CLASS(RMSPropSolver);
 INSTANTIATE_CLASS(AdaDeltaSolver);
+INSTANTIATE_CLASS(AdamSolver);
 
 }  // namespace caffe
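A note on the correction term above: it folds the paper's two bias corrections, \hat{m} = m_t/(1 - \beta_1^t) and \hat{v} = v_t/(1 - \beta_2^t), into a single scalar per iteration. A small standalone check (illustrative only, not part of the patch) shows it equals \sqrt{1 - \beta_2}/(1 - \beta_1) \approx 0.316 at t = 1 and tends to 1 for large t:

    #include <cmath>
    #include <cstdio>

    int main() {
      const double beta1 = 0.9, beta2 = 0.999;
      for (int t = 1; t <= 10000; t *= 10) {
        // same expression as in AdamSolver<Dtype>::ComputeUpdateValue
        const double correction =
            std::sqrt(1.0 - std::pow(beta2, t)) / (1.0 - std::pow(beta1, t));
        std::printf("t = %5d  correction = %.4f\n", t, correction);
      }
      return 0;
    }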
diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp
index 1d255a86621..fac4a9aa679 100644
--- a/src/caffe/test/test_gradient_based_solver.cpp
+++ b/src/caffe/test/test_gradient_based_solver.cpp
@@ -67,7 +67,8 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
     InitSolver(param);
     delta_ = (solver_type() == SolverParameter_SolverType_ADAGRAD ||
               solver_type() == SolverParameter_SolverType_RMSPROP ||
-              solver_type() == SolverParameter_SolverType_ADADELTA) ?
+              solver_type() == SolverParameter_SolverType_ADADELTA ||
+              solver_type() == SolverParameter_SolverType_ADAM) ?
          param.delta() : 0;
   }
@@ -253,6 +254,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
     updated_bias.ReshapeLike(bias);
 
     for (int i = 0; i <= D; ++i) {
+      // Compute the derivative with respect to the ith weight (i.e., the ith
       // element of the gradient).
       Dtype grad = 0;
@@ -275,6 +277,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
         const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i];
         grad -= element_i * targets.cpu_data()[k];
       }
+      // Scale the gradient over the N samples.
       grad /= N;
       // Add the weight decay to the gradient.
@@ -282,7 +285,8 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
           ((i == D) ? bias.cpu_data()[0] : weights.cpu_data()[i]);
       // Finally, compute update.
       const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
-      if (solver_type() != SolverParameter_SolverType_ADADELTA) {
+      if (solver_type() != SolverParameter_SolverType_ADADELTA
+          && solver_type() != SolverParameter_SolverType_ADAM) {
         ASSERT_EQ(2, history.size());  // 1 blob for weights, 1 for bias
       } else {
         ASSERT_EQ(4, history.size());  // additional blobs for update history
@@ -322,6 +326,18 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
         //   momentum * update_history_value + (1 - momentum) * (update_value);
         break;
       }
+      case SolverParameter_SolverType_ADAM: {
+        const Dtype momentum2 = 0.999;
+        const Dtype m = (i == D) ?
+            history[1]->cpu_data()[0] : history[0]->cpu_data()[i];
+        const Dtype v = (i == D) ?
+            history[3]->cpu_data()[0] : history[2]->cpu_data()[i];
+        const Dtype val_m = (1 - momentum) * grad + momentum * m;
+        const Dtype val_v = (1 - momentum2) * grad * grad + momentum2 * v;
+        Dtype alpha_t = learning_rate * std::sqrt(1 - momentum2) / (1 - momentum);
+        update_value = alpha_t * val_m / (std::sqrt(val_v) + delta_);
+        break;
+      }
       default:
         LOG(FATAL) << "Unknown solver type: " << solver_type();
       }
@@ -1061,6 +1077,111 @@ TYPED_TEST(AdaDeltaSolverTest, TestSnapshotShare) {
   }
 }
 
+template <typename TypeParam>
+class AdamSolverTest : public GradientBasedSolverTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+
+ protected:
+  virtual void InitSolver(const SolverParameter& param) {
+    SolverParameter new_param = param;
+    const Dtype momentum = 0.9;
+    new_param.set_momentum(momentum);
+    const Dtype momentum2 = 0.999;
+    new_param.set_momentum2(momentum2);
+    this->solver_.reset(new AdamSolver<Dtype>(new_param));
+  }
+  virtual SolverParameter_SolverType solver_type() {
+    return SolverParameter_SolverType_ADAM;
+  }
+};
+
+TYPED_TEST_CASE(AdamSolverTest, TestDtypesAndDevices);
+
+TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdate) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.0;
+  const Dtype kMomentum = 0.9;
+  this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
+}
+
+TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdateWithWeightDecay) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.9;
+  this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
+}
+
+TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdateWithEverything) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.9;
+  const int kNumIters = 4;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
+TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdateWithEverythingShare) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.9;
+  const int kNumIters = 4;
+  this->share_ = true;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
+TYPED_TEST(AdamSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.9;
+  const int kNumIters = 4;
+  const int kIterSize = 2;
+  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
+      kIterSize);
+}
+
+TYPED_TEST(AdamSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.9;
+  const int kNumIters = 4;
+  const int kIterSize = 2;
+  this->share_ = true;
+  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
+      kIterSize);
+}
+
+TYPED_TEST(AdamSolverTest, TestSnapshot) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.9;
+  const int kNumIters = 4;
+  for (int i = 1; i <= kNumIters; ++i) {
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
+TYPED_TEST(AdamSolverTest, TestSnapshotShare) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.9;
+  const int kNumIters = 4;
+  this->share_ = true;
+  for (int i = 1; i <= kNumIters; ++i) {
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
 template <typename TypeParam>
 class RMSPropSolverTest : public GradientBasedSolverTest<TypeParam> {
   typedef typename TypeParam::Dtype Dtype;