diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 2b1896101d3..d4322d37cbf 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -636,7 +636,8 @@ void SGDSolver::ComputeUpdateValue() { // Compute the value to history, and then copy them to the blob's diff. Dtype local_rate = rate * net_params_lr[param_id] / this->param_.iter_size(); - Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; + Dtype local_decay = weight_decay * net_params_weight_decay[param_id] + * this->param_.iter_size(); if (local_decay) { if (regularization_type == "L2") { @@ -673,7 +674,8 @@ void SGDSolver::ComputeUpdateValue() { // Compute the value to history, and then copy them to the blob's diff. Dtype local_rate = rate * net_params_lr[param_id] / this->param_.iter_size(); - Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; + Dtype local_decay = weight_decay * net_params_weight_decay[param_id] + * this->param_.iter_size(); if (local_decay) { if (regularization_type == "L2") {