diff --git a/include/caffe/net.hpp b/include/caffe/net.hpp
index 23da5dbbf2d..413c9f77d84 100644
--- a/include/caffe/net.hpp
+++ b/include/caffe/net.hpp
@@ -84,6 +84,13 @@ class Net {
 
   /// @brief Updates the network weights based on the diff values computed.
   void Update();
+  /**
+   * @brief Shares weight data of owner blobs with shared blobs.
+   *
+   * Note: this is called by Net::Init, and thus should normally not be
+   * called manually.
+   */
+  void ShareWeightData();
 
   /**
    * @brief For an already initialized net, implicitly copies (i.e., using no
@@ -148,6 +155,9 @@
     return param_names_index_;
   }
   inline const vector<int>& param_owners() const { return param_owners_; }
+  inline const vector<string>& param_display_names() const {
+    return param_display_names_;
+  }
   /// @brief Input and output blob numbers
   inline int num_inputs() const { return net_input_blobs_.size(); }
   inline int num_outputs() const { return net_output_blobs_.size(); }
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index 83c417f3c34..5ba609120eb 100644
--- a/src/caffe/net.cpp
+++ b/src/caffe/net.cpp
@@ -206,6 +206,7 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
     layer_names_index_[layer_names_[layer_id]] = layer_id;
   }
   GetLearningRateAndWeightDecay();
+  ShareWeightData();
   debug_info_ = param.debug_info();
   LOG(INFO) << "Network initialization done.";
   LOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype);
@@ -444,8 +445,6 @@ void Net<Dtype>::AppendParam(const NetParameter& param, const int layer_id,
       // Strict dimension checking -- all dims must be the same.
       CHECK(this_blob->shape() == owner_blob->shape());
     }
-    layers_[layer_id]->blobs()[param_id]->ShareData(
-        *layers_[owner_layer_id]->blobs()[owner_param_id]);
   }
 }
 
@@ -749,35 +748,7 @@ void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) const {
 
 template <typename Dtype>
 void Net<Dtype>::Update() {
-  // First, accumulate the diffs of any shared parameters into their owner's
-  // diff. (Assumes that the learning rate, weight decay, etc. have already been
-  // accounted for in the current diff.)
-  for (int i = 0; i < params_.size(); ++i) {
-    if (param_owners_[i] < 0) { continue; }
-    if (debug_info_) { UpdateDebugInfo(i); }
-    const int count = params_[i]->count();
-    const Dtype* this_diff;
-    Dtype* owner_diff;
-    switch (Caffe::mode()) {
-    case Caffe::CPU:
-      this_diff = params_[i]->cpu_diff();
-      owner_diff = params_[param_owners_[i]]->mutable_cpu_diff();
-      caffe_add(count, this_diff, owner_diff, owner_diff);
-      break;
-#ifndef CPU_ONLY
-    case Caffe::GPU:
-      this_diff = params_[i]->gpu_diff();
-      owner_diff = params_[param_owners_[i]]->mutable_gpu_diff();
-      caffe_gpu_add(count, this_diff, owner_diff, owner_diff);
-      break;
-#else
-      NO_GPU;
-#endif
-    default:
-      LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
-    }
-  }
-  // Now, update the owned parameters.
+  // Update only the owned parameters.
   for (int i = 0; i < params_.size(); ++i) {
     if (param_owners_[i] >= 0) { continue; }
     if (debug_info_) { UpdateDebugInfo(i); }
@@ -785,6 +756,15 @@
   }
 }
 
+template <typename Dtype>
+void Net<Dtype>::ShareWeightData() {
+  for (int i = 0; i < params_.size(); ++i) {
+    if (param_owners_[i] < 0) { continue; }
+    params_[i]->ShareData(*params_[param_owners_[i]]);
+    params_[i]->ShareDiff(*params_[param_owners_[i]]);
+  }
+}
+
 template <typename Dtype>
 bool Net<Dtype>::has_blob(const string& blob_name) const {
   return blob_names_index_.find(blob_name) != blob_names_index_.end();
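
Note (not part of the patch): dropping the diff-accumulation loop from Net::Update() works because ShareWeightData() now aliases each shared parameter's data *and* diff to its owner's underlying SyncedMemory, so a gradient written through any sharing blob lands directly in the owner's diff buffer. Below is a minimal standalone sketch of that aliasing, assuming only the stock Blob<Dtype> API from caffe/blob.hpp; the blob shapes and variable names are illustrative, not taken from the patch.

#include <cassert>

#include "caffe/blob.hpp"

int main() {
  // Owner and sharer must have identical counts for ShareData/ShareDiff.
  caffe::Blob<float> owner(1, 1, 2, 3);
  caffe::Blob<float> sharer(1, 1, 2, 3);

  // Mirror what Net::ShareWeightData() does for each shared parameter:
  // after sharing, both blobs point at the owner's data and diff memory.
  sharer.ShareData(owner);
  sharer.ShareDiff(owner);
  assert(sharer.cpu_data() == owner.cpu_data());
  assert(sharer.cpu_diff() == owner.cpu_diff());

  // A gradient written through the sharer is immediately visible in the
  // owner's diff, which is why Update() no longer needs an explicit
  // accumulation pass over shared parameters.
  sharer.mutable_cpu_diff()[0] = 1.0f;
  assert(owner.cpu_diff()[0] == 1.0f);
  return 0;
}

Because the data side is shared as well, every layer declaring the same param name trains against a single copy of the weights, and Update() only touches the owner entries (those with param_owners_[i] < 0).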