From 41cf06cc6e40e1b41d04b5b26e19395611bdcf5d Mon Sep 17 00:00:00 2001 From: Jonathan L Long Date: Mon, 11 Aug 2014 21:38:59 -0700 Subject: [PATCH 1/8] zero-init param diffs and accumulate gradients (With layers whose backward accumulates gradients), this effectively decouples the computational batch from the SGD minibatch. Each iteration accumulates gradients over iter_size batches, then parameters are updated. --- src/caffe/proto/caffe.proto | 4 +++- src/caffe/solver.cpp | 27 ++++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index c471fa0a93e..94836421a42 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -96,7 +96,7 @@ message NetParameter { // NOTE // Update the next available ID when you add a new SolverParameter field. // -// SolverParameter next available ID: 36 (last added: clip_gradients) +// SolverParameter next available ID: 37 (last added: iter_size) message SolverParameter { ////////////////////////////////////////////////////////////////////////////// // Specifying the train and test networks @@ -149,6 +149,8 @@ message SolverParameter { // Display the loss averaged over the last average_loss iterations optional int32 average_loss = 33 [default = 1]; optional int32 max_iter = 7; // the maximum number of iterations + // accumulate gradients over `iter_size` x `batch_size` instances + optional int32 iter_size = 36 [default = 1]; optional string lr_policy = 8; // The learning rate decay policy. optional float gamma = 9; // The parameter to compute the learning rate. optional float power = 10; // The parameter to compute the learning rate. diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index fa334edaa60..ad041b8f268 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -168,6 +168,25 @@ void Solver::Step(int iters) { Dtype smoothed_loss = 0; while (iter_ < stop_iter) { + // zero-init the params + for (int i = 0; i < net_->params().size(); ++i) { + shared_ptr > blob = net_->params()[i]; + switch(Caffe::mode()) { + case Caffe::CPU: + caffe_set(blob->count(), static_cast(0), + blob->mutable_cpu_diff()); + break; + case Caffe::GPU: +#ifndef CPU_ONLY + caffe_gpu_set(blob->count(), static_cast(0), + blob->mutable_gpu_diff()); +#else + NO_GPU; +#endif + break; + } + } + if (param_.test_interval() && iter_ % param_.test_interval() == 0 && (iter_ > 0 || param_.test_initialization())) { TestAll(); @@ -175,7 +194,13 @@ void Solver::Step(int iters) { const bool display = param_.display() && iter_ % param_.display() == 0; net_->set_debug_info(display && param_.debug_info()); - Dtype loss = net_->ForwardBackward(bottom_vec); + // accumulate the loss and gradient + Dtype loss = 0; + for (int i = 0; i < param_.iter_size(); ++i) { + loss += net_->ForwardBackward(bottom_vec); + } + loss /= param_.iter_size(); + // average the loss across iterations for smoothed reporting if (losses.size() < average_loss) { losses.push_back(loss); int size = losses.size(); From 539f8798233e25ac8110f545995f5b8f7340718f Mon Sep 17 00:00:00 2001 From: Jonathan L Long Date: Tue, 30 Dec 2014 22:52:07 -0800 Subject: [PATCH 2/8] zero-init param diffs in gradient checker --- include/caffe/test/test_gradient_check_util.hpp | 7 +++++-- src/caffe/solver.cpp | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/include/caffe/test/test_gradient_check_util.hpp b/include/caffe/test/test_gradient_check_util.hpp index 22937711b58..cc5dcbad0ee 100644 --- a/include/caffe/test/test_gradient_check_util.hpp +++ b/include/caffe/test/test_gradient_check_util.hpp @@ -80,11 +80,14 @@ void GradientChecker::CheckGradientSingle(Layer* layer, CHECK_EQ(top_count, bottom[blob_id]->count()); } } - // First, figure out what blobs we need to check against. + // First, figure out what blobs we need to check against, and zero init + // parameter blobs. vector*> blobs_to_check; vector propagate_down(bottom.size(), check_bottom < 0); for (int i = 0; i < layer->blobs().size(); ++i) { - blobs_to_check.push_back(layer->blobs()[i].get()); + Blob* blob = layer->blobs()[i].get(); + caffe_set(blob->count(), static_cast(0), blob->mutable_cpu_diff()); + blobs_to_check.push_back(blob); } if (check_bottom < 0) { for (int i = 0; i < bottom.size(); ++i) { diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index ad041b8f268..d104522002b 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -171,7 +171,7 @@ void Solver::Step(int iters) { // zero-init the params for (int i = 0; i < net_->params().size(); ++i) { shared_ptr > blob = net_->params()[i]; - switch(Caffe::mode()) { + switch (Caffe::mode()) { case Caffe::CPU: caffe_set(blob->count(), static_cast(0), blob->mutable_cpu_diff()); From 3262e464b06f1ecd8a04db1487cc5878c0cfd852 Mon Sep 17 00:00:00 2001 From: Sergio Date: Fri, 26 Sep 2014 23:03:26 -0700 Subject: [PATCH 3/8] accumulate gradients in inner product layer --- src/caffe/layers/inner_product_layer.cpp | 4 ++-- src/caffe/layers/inner_product_layer.cu | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/caffe/layers/inner_product_layer.cpp b/src/caffe/layers/inner_product_layer.cpp index 89e0c8fbad7..83c3235eb71 100644 --- a/src/caffe/layers/inner_product_layer.cpp +++ b/src/caffe/layers/inner_product_layer.cpp @@ -101,13 +101,13 @@ void InnerProductLayer::Backward_cpu(const vector*>& top, const Dtype* bottom_data = bottom[0]->cpu_data(); // Gradient with respect to weight caffe_cpu_gemm(CblasTrans, CblasNoTrans, N_, K_, M_, (Dtype)1., - top_diff, bottom_data, (Dtype)0., this->blobs_[0]->mutable_cpu_diff()); + top_diff, bottom_data, (Dtype)1., this->blobs_[0]->mutable_cpu_diff()); } if (bias_term_ && this->param_propagate_down_[1]) { const Dtype* top_diff = top[0]->cpu_diff(); // Gradient with respect to bias caffe_cpu_gemv(CblasTrans, M_, N_, (Dtype)1., top_diff, - bias_multiplier_.cpu_data(), (Dtype)0., + bias_multiplier_.cpu_data(), (Dtype)1., this->blobs_[1]->mutable_cpu_diff()); } if (propagate_down[0]) { diff --git a/src/caffe/layers/inner_product_layer.cu b/src/caffe/layers/inner_product_layer.cu index a9e1784a205..dd90cac12a8 100644 --- a/src/caffe/layers/inner_product_layer.cu +++ b/src/caffe/layers/inner_product_layer.cu @@ -33,13 +33,13 @@ void InnerProductLayer::Backward_gpu(const vector*>& top, const Dtype* bottom_data = bottom[0]->gpu_data(); // Gradient with respect to weight caffe_gpu_gemm(CblasTrans, CblasNoTrans, N_, K_, M_, (Dtype)1., - top_diff, bottom_data, (Dtype)0., this->blobs_[0]->mutable_gpu_diff()); + top_diff, bottom_data, (Dtype)1., this->blobs_[0]->mutable_gpu_diff()); } if (bias_term_ && this->param_propagate_down_[1]) { const Dtype* top_diff = top[0]->gpu_diff(); // Gradient with respect to bias caffe_gpu_gemv(CblasTrans, M_, N_, (Dtype)1., top_diff, - bias_multiplier_.gpu_data(), (Dtype)0., + bias_multiplier_.gpu_data(), (Dtype)1., this->blobs_[1]->mutable_gpu_diff()); } if (propagate_down[0]) { From 8cc9af01941cd8c7c32c664672e5f31106d2cc40 Mon Sep 17 00:00:00 2001 From: Jonathan L Long Date: Tue, 30 Dec 2014 22:29:35 -0800 Subject: [PATCH 4/8] accumulate gradients in (de)conv layers --- src/caffe/layers/conv_layer.cpp | 7 ------- src/caffe/layers/conv_layer.cu | 7 ------- src/caffe/layers/deconv_layer.cpp | 7 ------- src/caffe/layers/deconv_layer.cu | 7 ------- 4 files changed, 28 deletions(-) diff --git a/src/caffe/layers/conv_layer.cpp b/src/caffe/layers/conv_layer.cpp index c0c9f6f3371..928ef5ee468 100644 --- a/src/caffe/layers/conv_layer.cpp +++ b/src/caffe/layers/conv_layer.cpp @@ -39,13 +39,6 @@ void ConvolutionLayer::Backward_cpu(const vector*>& top, const vector& propagate_down, const vector*>& bottom) { const Dtype* weight = this->blobs_[0]->cpu_data(); Dtype* weight_diff = this->blobs_[0]->mutable_cpu_diff(); - if (this->param_propagate_down_[0]) { - caffe_set(this->blobs_[0]->count(), Dtype(0), weight_diff); - } - if (this->bias_term_ && this->param_propagate_down_[1]) { - caffe_set(this->blobs_[1]->count(), Dtype(0), - this->blobs_[1]->mutable_cpu_diff()); - } for (int i = 0; i < top.size(); ++i) { const Dtype* top_diff = top[i]->cpu_diff(); const Dtype* bottom_data = bottom[i]->cpu_data(); diff --git a/src/caffe/layers/conv_layer.cu b/src/caffe/layers/conv_layer.cu index 3902fdf3930..b8a98ff7cc9 100644 --- a/src/caffe/layers/conv_layer.cu +++ b/src/caffe/layers/conv_layer.cu @@ -31,13 +31,6 @@ void ConvolutionLayer::Backward_gpu(const vector*>& top, const vector& propagate_down, const vector*>& bottom) { const Dtype* weight = this->blobs_[0]->gpu_data(); Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff(); - if (this->param_propagate_down_[0]) { - caffe_gpu_set(this->blobs_[0]->count(), Dtype(0), weight_diff); - } - if (this->bias_term_ && this->param_propagate_down_[1]) { - caffe_gpu_set(this->blobs_[1]->count(), Dtype(0), - this->blobs_[1]->mutable_gpu_diff()); - } for (int i = 0; i < top.size(); ++i) { const Dtype* top_diff = top[i]->gpu_diff(); // Bias gradient, if necessary. diff --git a/src/caffe/layers/deconv_layer.cpp b/src/caffe/layers/deconv_layer.cpp index e6d65ab526b..a4612963b6b 100644 --- a/src/caffe/layers/deconv_layer.cpp +++ b/src/caffe/layers/deconv_layer.cpp @@ -39,13 +39,6 @@ void DeconvolutionLayer::Backward_cpu(const vector*>& top, const vector& propagate_down, const vector*>& bottom) { const Dtype* weight = this->blobs_[0]->cpu_data(); Dtype* weight_diff = this->blobs_[0]->mutable_cpu_diff(); - if (this->param_propagate_down_[0]) { - caffe_set(this->blobs_[0]->count(), Dtype(0), weight_diff); - } - if (this->bias_term_ && this->param_propagate_down_[1]) { - caffe_set(this->blobs_[1]->count(), Dtype(0), - this->blobs_[1]->mutable_cpu_diff()); - } for (int i = 0; i < top.size(); ++i) { const Dtype* top_diff = top[i]->cpu_diff(); const Dtype* bottom_data = bottom[i]->cpu_data(); diff --git a/src/caffe/layers/deconv_layer.cu b/src/caffe/layers/deconv_layer.cu index 9198dd64c72..39bc4de8c66 100644 --- a/src/caffe/layers/deconv_layer.cu +++ b/src/caffe/layers/deconv_layer.cu @@ -31,13 +31,6 @@ void DeconvolutionLayer::Backward_gpu(const vector*>& top, const vector& propagate_down, const vector*>& bottom) { const Dtype* weight = this->blobs_[0]->gpu_data(); Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff(); - if (this->param_propagate_down_[0]) { - caffe_gpu_set(this->blobs_[0]->count(), Dtype(0), weight_diff); - } - if (this->bias_term_ && this->param_propagate_down_[1]) { - caffe_gpu_set(this->blobs_[1]->count(), Dtype(0), - this->blobs_[1]->mutable_gpu_diff()); - } for (int i = 0; i < top.size(); ++i) { const Dtype* top_diff = top[i]->gpu_diff(); const Dtype* bottom_data = bottom[i]->gpu_data(); From 67b1ff3114320188b5046ff899d4fb0f87fd7b63 Mon Sep 17 00:00:00 2001 From: Jonathan L Long Date: Sat, 13 Sep 2014 17:41:59 -0700 Subject: [PATCH 5/8] accumulate gradients in cudnn conv layer --- src/caffe/layers/cudnn_conv_layer.cu | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/caffe/layers/cudnn_conv_layer.cu b/src/caffe/layers/cudnn_conv_layer.cu index 4a1a4c4f4f2..b4e802e13d1 100644 --- a/src/caffe/layers/cudnn_conv_layer.cu +++ b/src/caffe/layers/cudnn_conv_layer.cu @@ -101,12 +101,10 @@ void CuDNNConvolutionLayer::Backward_gpu(const vector*>& top, if (this->param_propagate_down_[0]) { weight = this->blobs_[0]->gpu_data(); weight_diff = this->blobs_[0]->mutable_gpu_diff(); - caffe_gpu_set(this->blobs_[0]->count(), Dtype(0), weight_diff); } Dtype* bias_diff = NULL; if (this->bias_term_ && this->param_propagate_down_[1]) { bias_diff = this->blobs_[1]->mutable_gpu_diff(); - caffe_gpu_set(this->blobs_[1]->count(), Dtype(0), bias_diff); } for (int i = 0; i < top.size(); ++i) { const Dtype* top_diff = top[i]->gpu_diff(); From 55585f5bfab61328a61125b3d49627a69022d817 Mon Sep 17 00:00:00 2001 From: Evan Shelhamer Date: Thu, 21 May 2015 17:06:42 -0700 Subject: [PATCH 6/8] adjust local learning rate and decay according to gradient accumulation Divide local rate by `iter_size` to normalize the gradient according to the full minibatch size and not only the computational batch size. Multiply the local decay by `iter_size` to counter the division of the local learning rate since the decay is multiplied by the rate in the update equation. --- src/caffe/solver.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index d104522002b..4c8fa25c955 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -488,7 +488,7 @@ void SGDSolver::ApplyUpdate() { ClipGradients(); for (int param_id = 0; param_id < this->net_->params().size(); ++param_id) { Regularize(param_id); - ComputeUpdateValue(param_id, rate); + ComputeUpdateValue(param_id, rate / this->param_.iter_size()); } this->net_->Update(); } @@ -500,7 +500,8 @@ void SGDSolver::Regularize(int param_id) { this->net_->params_weight_decay(); Dtype weight_decay = this->param_.weight_decay(); string regularization_type = this->param_.regularization_type(); - Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; + Dtype local_decay = weight_decay * net_params_weight_decay[param_id] + * this->param_.iter_size(); switch (Caffe::mode()) { case Caffe::CPU: { if (local_decay) { From 92ab737adad6d686ac75cdf934472f6a97b52fe7 Mon Sep 17 00:00:00 2001 From: Evan Shelhamer Date: Thu, 21 May 2015 18:14:16 -0700 Subject: [PATCH 7/8] test equivalence of solving with accumulating gradients Compare the parameters after solving with a given batch size and the halved batch size + two iter accumulation of gradients equivalent. Note: the test net dummy data layer now makes constant data and random gaussian targets. This assures the standard and gradient accumulation cases check the same data. Otherwise the difference in batch sizes causes different orders of random number draws. --- src/caffe/test/test_gradient_based_solver.cpp | 82 ++++++++++++++++++- 1 file changed, 79 insertions(+), 3 deletions(-) diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp index eb2569c04f2..c9135d64e70 100644 --- a/src/caffe/test/test_gradient_based_solver.cpp +++ b/src/caffe/test/test_gradient_based_solver.cpp @@ -23,7 +23,7 @@ class GradientBasedSolverTest : public MultiDeviceTest { protected: GradientBasedSolverTest() : - seed_(1701), num_(5), channels_(3), height_(10), width_(10) {} + seed_(1701), num_(4), channels_(3), height_(10), width_(10) {} shared_ptr > solver_; int seed_; @@ -56,19 +56,21 @@ class GradientBasedSolverTest : public MultiDeviceTest { } void RunLeastSquaresSolver(const Dtype learning_rate, - const Dtype weight_decay, const Dtype momentum, const int num_iters) { + const Dtype weight_decay, const Dtype momentum, const int num_iters, + const int iter_size = 1) { ostringstream proto; proto << "max_iter: " << num_iters << " " "base_lr: " << learning_rate << " " "lr_policy: 'fixed' " + "iter_size: " << iter_size << " " "net_param { " " name: 'TestNetwork' " " layer { " " name: 'data' " " type: 'DummyData' " " dummy_data_param { " - " num: " << num_ << " " + " num: " << num_ / iter_size << " " " channels: " << channels_ << " " " height: " << height_ << " " " width: " << width_ << " " @@ -76,6 +78,10 @@ class GradientBasedSolverTest : public MultiDeviceTest { " height: 1 " " width: 1 " " data_filler { " + " type: 'constant' " + " value: 1.0 " + " } " + " data_filler { " " type: 'gaussian' " " std: 1.0 " " } " @@ -270,6 +276,45 @@ class GradientBasedSolverTest : public MultiDeviceTest { } } + void CheckAccumulation(const Dtype kLearningRate, const Dtype kWeightDecay, + const Dtype kMomentum, const int kNumIters, const int kIterSize) { + const double kPrecision = 1e-2; + const double kMinPrecision = 1e-7; + // Solve without accumulation and save parameters. + this->RunLeastSquaresSolver(kLearningRate, kWeightDecay, kMomentum, + kNumIters); + // Save parameters for comparison. + Net& net = *this->solver_->net(); + const vector > >& param_blobs = + net.layer_by_name("innerprod")->blobs(); + vector > > noaccum_params(param_blobs.size()); + for (int i = 0; i < param_blobs.size(); ++i) { + noaccum_params[i].reset(new Blob()); + noaccum_params[i]->CopyFrom(*param_blobs[i], false, true); + } + // Solve by equivalent accumulation of gradients over divided batches. + this->RunLeastSquaresSolver(kLearningRate, kWeightDecay, kMomentum, + kNumIters, kIterSize); + Net& net_accum = *this->solver_->net(); + const vector > >& accum_params = + net_accum.layer_by_name("innerprod")->blobs(); + // Compare accumulated parameters against no accumulation standard. + const int D = this->channels_ * this->height_ * this->width_; + for (int i = 0; i < D; ++i) { + const Dtype expected_param = noaccum_params[0]->cpu_data()[i]; + const Dtype accum_param = accum_params[0]->cpu_data()[i]; + const Dtype error_margin = std::max(kMinPrecision, kPrecision * + std::min(fabs(expected_param), fabs(accum_param))); + EXPECT_NEAR(expected_param, accum_param, error_margin); + } + ASSERT_EQ(1, accum_params[1]->count()); + const Dtype expected_bias = noaccum_params[1]->cpu_data()[0]; + const Dtype accum_bias = accum_params[1]->cpu_data()[0]; + const Dtype error_margin = std::max(kMinPrecision, kPrecision * + std::min(fabs(expected_bias), fabs(accum_bias))); + EXPECT_NEAR(expected_bias, accum_bias, error_margin); + } + // Test that the correct update is computed for a regularized least squares // problem: // @@ -372,6 +417,16 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverything) { } } +TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingAccum) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 0.1; + const Dtype kMomentum = 0.9; + const int kNumIters = 4; + const int kIterSize = 2; + this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters, + kIterSize); +} template class AdaGradSolverTest : public GradientBasedSolverTest { @@ -416,6 +471,16 @@ TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithEverything) { } } +TYPED_TEST(AdaGradSolverTest, TestLeastSquaresUpdateWithEverythingAccum) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 0.1; + const Dtype kMomentum = 0.0; + const int kNumIters = 4; + const int kIterSize = 2; + this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters, + kIterSize); +} template class NesterovSolverTest : public GradientBasedSolverTest { @@ -482,4 +547,15 @@ TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithEverything) { } } +TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithEverythingAccum) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 0.1; + const Dtype kMomentum = 0.9; + const int kNumIters = 4; + const int kIterSize = 2; + this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters, + kIterSize); +} + } // namespace caffe From 0e7a0785db224aa7cf2bd925d8b7910bdc3f7a98 Mon Sep 17 00:00:00 2001 From: Evan Shelhamer Date: Thu, 28 May 2015 12:43:29 -0700 Subject: [PATCH 8/8] directly normalize accumulated gradients `SGDSolver::Normalize()` normalizes accumulated gradients by scaling inversely to the accumulation as `1 / iter_size`. This fixes accumulation for AdaGrad and is more obvious than fooling with rates and decays in 55585f5. --- include/caffe/solver.hpp | 1 + src/caffe/solver.cpp | 32 +++++++++++++++++++++++++++++--- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index da1bab13663..c2ced487d6f 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -81,6 +81,7 @@ class SGDSolver : public Solver { void PreSolve(); Dtype GetLearningRate(); virtual void ApplyUpdate(); + virtual void Normalize(int param_id); virtual void Regularize(int param_id); virtual void ComputeUpdateValue(int param_id, Dtype rate); virtual void ClipGradients(); diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 4c8fa25c955..aabe0edec80 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -487,12 +487,39 @@ void SGDSolver::ApplyUpdate() { } ClipGradients(); for (int param_id = 0; param_id < this->net_->params().size(); ++param_id) { + Normalize(param_id); Regularize(param_id); - ComputeUpdateValue(param_id, rate / this->param_.iter_size()); + ComputeUpdateValue(param_id, rate); } this->net_->Update(); } +template +void SGDSolver::Normalize(int param_id) { + if (this->param_.iter_size() == 1) { return; } + // Scale gradient to counterbalance accumulation. + const vector > >& net_params = this->net_->params(); + const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size(); + switch (Caffe::mode()) { + case Caffe::CPU: { + caffe_scal(net_params[param_id]->count(), accum_normalization, + net_params[param_id]->mutable_cpu_diff()); + break; + } + case Caffe::GPU: { +#ifndef CPU_ONLY + caffe_gpu_scal(net_params[param_id]->count(), accum_normalization, + net_params[param_id]->mutable_gpu_diff()); +#else + NO_GPU; +#endif + break; + } + default: + LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + } +} + template void SGDSolver::Regularize(int param_id) { const vector > >& net_params = this->net_->params(); @@ -500,8 +527,7 @@ void SGDSolver::Regularize(int param_id) { this->net_->params_weight_decay(); Dtype weight_decay = this->param_.weight_decay(); string regularization_type = this->param_.regularization_type(); - Dtype local_decay = weight_decay * net_params_weight_decay[param_id] - * this->param_.iter_size(); + Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; switch (Caffe::mode()) { case Caffe::CPU: { if (local_decay) {