diff --git a/include/caffe/common_layers.hpp b/include/caffe/common_layers.hpp
index c0389e745a9..fc8f2c48898 100644
--- a/include/caffe/common_layers.hpp
+++ b/include/caffe/common_layers.hpp
@@ -186,6 +186,44 @@ class EltwiseLayer : public Layer<Dtype> {
   bool stable_prod_grad_;
 };
 
+/**
+ * @brief A layer for learning "embeddings" of one-hot vector input.
+ *        Equivalent to an InnerProductLayer with one-hot vectors as input, but
+ *        for efficiency the input is the "hot" index of each column itself.
+ *
+ * TODO(dox): thorough documentation for Forward, Backward, and proto params.
+ */
+template <typename Dtype>
+class EmbedLayer : public Layer<Dtype> {
+ public:
+  explicit EmbedLayer(const LayerParameter& param)
+      : Layer<Dtype>(param) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+
+  virtual inline const char* type() const { return "Embed"; }
+  virtual inline int ExactNumBottomBlobs() const { return 1; }
+  virtual inline int ExactNumTopBlobs() const { return 1; }
+
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+
+  int M_;
+  int K_;
+  int N_;
+  bool bias_term_;
+  Blob<Dtype> bias_multiplier_;
+};
+
 /**
  * @brief Takes two+ Blobs, interprets last Blob as a selector and
  *  filter remaining Blobs accordingly with selector data (0 means that
@@ -624,6 +662,35 @@ class SliceLayer : public Layer<Dtype> {
   vector<int> slice_point_;
 };
 
+/**
+ * @brief Copy a Blob along specified dimensions.
+ */
+template <typename Dtype>
+class TileLayer : public Layer<Dtype> {
+ public:
+  explicit TileLayer(const LayerParameter& param)
+      : Layer<Dtype>(param) {}
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+
+  virtual inline const char* type() const { return "Tile"; }
+  virtual inline int ExactNumBottomBlobs() const { return 1; }
+  virtual inline int ExactNumTopBlobs() const { return 1; }
+
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+
+  unsigned int axis_, tiles_, outer_dim_, inner_dim_;
+};
+
 }  // namespace caffe
 
 #endif  // CAFFE_COMMON_LAYERS_HPP_
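As context for the header above (not part of the patch): an embedding lookup returns one row of the K x N weight matrix, which is exactly what an InnerProductLayer would compute for a one-hot input row. A minimal standalone C++ sketch of that equivalence, with made-up sizes:

    // Standalone illustration: an embedding lookup equals multiplying a
    // one-hot row vector by a K x N weight matrix.
    #include <cassert>
    #include <vector>

    int main() {
      const int K = 4, N = 3;  // input_dim, num_output
      std::vector<float> W(K * N);
      for (int i = 0; i < K * N; ++i) W[i] = 0.1f * i;  // arbitrary weights

      const int hot = 2;  // the "hot" index, i.e. the layer's actual input
      std::vector<float> onehot(K, 0.0f);
      onehot[hot] = 1.0f;
      for (int j = 0; j < N; ++j) {
        // One-hot inner product: y[j] = sum_k onehot[k] * W[k*N + j].
        float y = 0.0f;
        for (int k = 0; k < K; ++k) y += onehot[k] * W[k * N + j];
        // EmbedLayer skips the multiply and just copies row `hot`:
        assert(y == W[hot * N + j]);
      }
      return 0;
    }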
diff --git a/include/caffe/test/test_gradient_check_util.hpp b/include/caffe/test/test_gradient_check_util.hpp
index cc5dcbad0ee..25f35d1589e 100644
--- a/include/caffe/test/test_gradient_check_util.hpp
+++ b/include/caffe/test/test_gradient_check_util.hpp
@@ -45,6 +45,10 @@ class GradientChecker {
   void CheckGradientEltwise(Layer<Dtype>* layer,
       const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
 
+  // Checks the gradient of a single output with respect to particular input
+  // blob(s). If check_bottom = i >= 0, check only the ith bottom Blob.
+  // If check_bottom == -1, check everything -- all bottom Blobs and all
+  // param Blobs. Otherwise (if check_bottom < -1), check only param Blobs.
   void CheckGradientSingle(Layer<Dtype>* layer,
       const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top,
       int check_bottom, int top_id, int top_data_id, bool element_wise = false);
@@ -83,21 +87,22 @@ void GradientChecker<Dtype>::CheckGradientSingle(Layer<Dtype>* layer,
   // First, figure out what blobs we need to check against, and zero init
   // parameter blobs.
   vector<Blob<Dtype>*> blobs_to_check;
-  vector<bool> propagate_down(bottom.size(), check_bottom < 0);
+  vector<bool> propagate_down(bottom.size(), check_bottom == -1);
   for (int i = 0; i < layer->blobs().size(); ++i) {
     Blob<Dtype>* blob = layer->blobs()[i].get();
     caffe_set(blob->count(), static_cast<Dtype>(0), blob->mutable_cpu_diff());
     blobs_to_check.push_back(blob);
   }
-  if (check_bottom < 0) {
+  if (check_bottom == -1) {
     for (int i = 0; i < bottom.size(); ++i) {
       blobs_to_check.push_back(bottom[i]);
     }
-  } else {
+  } else if (check_bottom >= 0) {
     CHECK_LT(check_bottom, bottom.size());
     blobs_to_check.push_back(bottom[check_bottom]);
     propagate_down[check_bottom] = true;
   }
+  CHECK_GT(blobs_to_check.size(), 0) << "No blobs to check.";
   // Compute the gradient analytically using Backward
   Caffe::set_random_seed(seed_);
   // Ignore the loss from the layer (it's just the weighted sum of the losses
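The change above splits the old two-way check_bottom convention into three modes. A hypothetical test fragment (the names `layer`, `bottom_vec`, and `top_vec` are placeholders, and `Dtype` comes from the enclosing typed test) showing all three:

    GradientChecker<Dtype> checker(1e-2, 1e-3);
    // check_bottom == -1 (the default): all bottom Blobs plus all param Blobs.
    checker.CheckGradientExhaustive(&layer, bottom_vec, top_vec);
    // check_bottom >= 0: only that bottom Blob, plus all param Blobs.
    checker.CheckGradientExhaustive(&layer, bottom_vec, top_vec, 0);
    // check_bottom < -1: param Blobs only; used by the EmbedLayer tests below,
    // whose integer-index input is not differentiable.
    checker.CheckGradientExhaustive(&layer, bottom_vec, top_vec, -2);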
diff --git a/include/caffe/util/gpu_util.cuh b/include/caffe/util/gpu_util.cuh
new file mode 100644
index 00000000000..994202f2a1a
--- /dev/null
+++ b/include/caffe/util/gpu_util.cuh
@@ -0,0 +1,35 @@
+#ifndef CAFFE_UTIL_GPU_UTIL_H_
+#define CAFFE_UTIL_GPU_UTIL_H_
+
+namespace caffe {
+
+template <typename Dtype>
+inline __device__ Dtype caffe_gpu_atomic_add(const Dtype val, Dtype* address);
+
+template <>
+inline __device__
+float caffe_gpu_atomic_add(const float val, float* address) {
+  return atomicAdd(address, val);
+}
+
+// double atomicAdd implementation taken from:
+// http://docs.nvidia.com/cuda/cuda-c-programming-guide/#axzz3PVCpVsEG
+template <>
+inline __device__
+double caffe_gpu_atomic_add(const double val, double* address) {
+  unsigned long long int* address_as_ull =  // NOLINT(runtime/int)
+      // NOLINT_NEXT_LINE(runtime/int)
+      reinterpret_cast<unsigned long long int*>(address);
+  unsigned long long int old = *address_as_ull;  // NOLINT(runtime/int)
+  unsigned long long int assumed;  // NOLINT(runtime/int)
+  do {
+    assumed = old;
+    old = atomicCAS(address_as_ull, assumed,
+        __double_as_longlong(val + __longlong_as_double(assumed)));
+  } while (assumed != old);
+  return __longlong_as_double(old);
+}
+
+}  // namespace caffe
+
+#endif  // CAFFE_UTIL_GPU_UTIL_H_
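The double specialization above emulates atomicAdd with a compare-and-swap retry loop on the value's 64-bit pattern. A CPU-side analogue of the same idea, sketched with std::atomic (illustration only; the device code relies on CUDA's atomicCAS and the __double_as_longlong / __longlong_as_double reinterpretation intrinsics instead):

    #include <atomic>
    #include <cstdint>
    #include <cstring>

    // Adds val to the double stored (as raw bits) in *address and returns
    // the previous value, retrying whenever another thread wins the race.
    double atomic_add_double(std::atomic<uint64_t>* address, double val) {
      uint64_t old_bits = address->load();
      for (;;) {
        double old_val;
        std::memcpy(&old_val, &old_bits, sizeof(double));
        const double new_val = old_val + val;
        uint64_t new_bits;
        std::memcpy(&new_bits, &new_val, sizeof(double));
        // On failure, compare_exchange refreshes old_bits and we retry.
        if (address->compare_exchange_weak(old_bits, new_bits)) return old_val;
      }
    }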
diff --git a/src/caffe/layers/concat_layer.cpp b/src/caffe/layers/concat_layer.cpp
index 1cac8fc3387..95fba105b9a 100644
--- a/src/caffe/layers/concat_layer.cpp
+++ b/src/caffe/layers/concat_layer.cpp
@@ -76,13 +76,14 @@ void ConcatLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
   int offset_concat_axis = 0;
   const int top_concat_axis = top[0]->shape(concat_axis_);
   for (int i = 0; i < bottom.size(); ++i) {
-    if (!propagate_down[i]) { continue; }
-    Dtype* bottom_diff = bottom[i]->mutable_cpu_diff();
     const int bottom_concat_axis = bottom[i]->shape(concat_axis_);
-    for (int n = 0; n < num_concats_; ++n) {
-      caffe_copy(bottom_concat_axis * concat_input_size_, top_diff +
-          (n * top_concat_axis + offset_concat_axis) * concat_input_size_,
-          bottom_diff + n * bottom_concat_axis * concat_input_size_);
+    if (propagate_down[i]) {
+      Dtype* bottom_diff = bottom[i]->mutable_cpu_diff();
+      for (int n = 0; n < num_concats_; ++n) {
+        caffe_copy(bottom_concat_axis * concat_input_size_, top_diff +
+            (n * top_concat_axis + offset_concat_axis) * concat_input_size_,
+            bottom_diff + n * bottom_concat_axis * concat_input_size_);
+      }
     }
     offset_concat_axis += bottom_concat_axis;
   }
diff --git a/src/caffe/layers/concat_layer.cu b/src/caffe/layers/concat_layer.cu
index 8f2e85d8f52..3c64c7ef224 100644
--- a/src/caffe/layers/concat_layer.cu
+++ b/src/caffe/layers/concat_layer.cu
@@ -53,15 +53,16 @@ void ConcatLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   const int top_concat_axis = top[0]->shape(concat_axis_);
   const bool kForward = false;
   for (int i = 0; i < bottom.size(); ++i) {
-    if (!propagate_down[i]) { continue; }
-    Dtype* bottom_diff = bottom[i]->mutable_gpu_diff();
     const int bottom_concat_axis = bottom[i]->shape(concat_axis_);
-    const int bottom_concat_size = bottom_concat_axis * concat_input_size_;
-    const int nthreads = bottom_concat_size * num_concats_;
-    Concat<Dtype>  // NOLINT_NEXT_LINE(whitespace/operators)
-        <<<CAFFE_GET_BLOCKS(nthreads), CAFFE_CUDA_NUM_THREADS>>>(
-        nthreads, top_diff, kForward, num_concats_, concat_input_size_,
-        top_concat_axis, bottom_concat_axis, offset_concat_axis, bottom_diff);
+    if (propagate_down[i]) {
+      Dtype* bottom_diff = bottom[i]->mutable_gpu_diff();
+      const int bottom_concat_size = bottom_concat_axis * concat_input_size_;
+      const int nthreads = bottom_concat_size * num_concats_;
+      Concat<Dtype>  // NOLINT_NEXT_LINE(whitespace/operators)
+          <<<CAFFE_GET_BLOCKS(nthreads), CAFFE_CUDA_NUM_THREADS>>>(
+          nthreads, top_diff, kForward, num_concats_, concat_input_size_,
+          top_concat_axis, bottom_concat_axis, offset_concat_axis, bottom_diff);
+    }
     offset_concat_axis += bottom_concat_axis;
   }
 }
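Both concat fixes above address the same bug: the old `continue` skipped not just the copy but also the `offset_concat_axis` update, so any bottom after a non-propagated one read the wrong slice of top_diff. A standalone plain-C++ sketch of the corrected bookkeeping (toy sizes, printf only):

    #include <cstdio>
    #include <vector>

    int main() {
      const std::vector<int> bottom_sizes = {2, 3, 4};  // extents on concat axis
      const std::vector<bool> propagate_down = {false, true, true};
      int offset = 0;
      for (size_t i = 0; i < bottom_sizes.size(); ++i) {
        if (propagate_down[i]) {
          std::printf("bottom %zu reads top_diff[%d..%d)\n",
                      i, offset, offset + bottom_sizes[i]);
        }
        offset += bottom_sizes[i];  // must run unconditionally (the bug fixed here)
      }
      return 0;
    }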
diff --git a/src/caffe/layers/embed_layer.cpp b/src/caffe/layers/embed_layer.cpp
new file mode 100644
index 00000000000..be6b2cd2727
--- /dev/null
+++ b/src/caffe/layers/embed_layer.cpp
@@ -0,0 +1,122 @@
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/common_layers.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top) {
+  N_ = this->layer_param_.embed_param().num_output();
+  CHECK_GT(N_, 0) << "EmbedLayer num_output must be positive.";
+  K_ = this->layer_param_.embed_param().input_dim();
+  CHECK_GT(K_, 0) << "EmbedLayer input_dim must be positive.";
+  bias_term_ = this->layer_param_.embed_param().bias_term();
+  // Check if we need to set up the weights
+  if (this->blobs_.size() > 0) {
+    LOG(INFO) << "Skipping parameter initialization";
+  } else {
+    if (bias_term_) {
+      this->blobs_.resize(2);
+    } else {
+      this->blobs_.resize(1);
+    }
+    // Initialize the weights --
+    // transposed from InnerProductLayer for spatial locality.
+    vector<int> weight_shape(2);
+    weight_shape[0] = K_;
+    weight_shape[1] = N_;
+    this->blobs_[0].reset(new Blob<Dtype>(weight_shape));
+    // fill the weights
+    shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>(
+        this->layer_param_.embed_param().weight_filler()));
+    weight_filler->Fill(this->blobs_[0].get());
+    // If necessary, initialize and fill the bias term
+    if (bias_term_) {
+      vector<int> bias_shape(1, N_);
+      this->blobs_[1].reset(new Blob<Dtype>(bias_shape));
+      shared_ptr<Filler<Dtype> > bias_filler(GetFiller<Dtype>(
+          this->layer_param_.embed_param().bias_filler()));
+      bias_filler->Fill(this->blobs_[1].get());
+    }
+  }  // parameter initialization
+  this->param_propagate_down_.resize(this->blobs_.size(), true);
+}
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top) {
+  // Figure out the dimensions
+  M_ = bottom[0]->count();
+  vector<int> top_shape = bottom[0]->shape();
+  top_shape.push_back(N_);
+  top[0]->Reshape(top_shape);
+  // Set up the bias multiplier
+  if (bias_term_) {
+    vector<int> bias_shape(1, M_);
+    bias_multiplier_.Reshape(bias_shape);
+    caffe_set(M_, Dtype(1), bias_multiplier_.mutable_cpu_data());
+  }
+}
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  const Dtype* weight = this->blobs_[0]->cpu_data();
+  Dtype* top_data = top[0]->mutable_cpu_data();
+  int index;
+  for (int n = 0; n < M_; ++n) {
+    index = static_cast<int>(bottom_data[n]);
+    DCHECK_GE(index, 0);
+    DCHECK_LT(index, K_);
+    DCHECK_EQ(static_cast<Dtype>(index), bottom_data[n]) << "non-integer input";
+    caffe_copy(N_, weight + index * N_, top_data + n * N_);
+  }
+  if (bias_term_) {
+    const Dtype* bias = this->blobs_[1]->cpu_data();
+    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, Dtype(1),
+        bias_multiplier_.cpu_data(), bias, Dtype(1), top_data);
+  }
+}
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+  CHECK(!propagate_down[0]) << "Can't backpropagate to EmbedLayer input.";
+  if (this->param_propagate_down_[0]) {
+    const Dtype* top_diff = top[0]->cpu_diff();
+    const Dtype* bottom_data = bottom[0]->cpu_data();
+    // Gradient with respect to weight
+    Dtype* weight_diff = this->blobs_[0]->mutable_cpu_diff();
+    int index;
+    for (int n = 0; n < M_; ++n) {
+      index = static_cast<int>(bottom_data[n]);
+      DCHECK_GE(index, 0);
+      DCHECK_LT(index, K_);
+      DCHECK_EQ(static_cast<Dtype>(index), bottom_data[n])
+          << "non-integer input";
+      caffe_axpy(N_, Dtype(1), top_diff + n * N_, weight_diff + index * N_);
+    }
+  }
+  if (bias_term_ && this->param_propagate_down_[1]) {
+    const Dtype* top_diff = top[0]->cpu_diff();
+    Dtype* bias_diff = this->blobs_[1]->mutable_cpu_diff();
+    caffe_cpu_gemv<Dtype>(CblasTrans, M_, N_, Dtype(1), top_diff,
+        bias_multiplier_.cpu_data(), Dtype(1), bias_diff);
+  }
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(EmbedLayer);
+#endif
+
+INSTANTIATE_CLASS(EmbedLayer);
+REGISTER_LAYER_CLASS(Embed);
+
+}  // namespace caffe
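For readers new to the data path: the CPU forward is a row gather and the weight gradient is a row scatter-add. A minimal plain-C++ reference of those two loops, stripped of Blob machinery and bias (float only, M indices into a K x N weight matrix):

    #include <vector>

    void embed_forward(const std::vector<int>& indices,            // size M
                       const std::vector<float>& weight,           // K x N
                       int N, std::vector<float>* top) {           // M x N
      for (size_t n = 0; n < indices.size(); ++n)
        for (int d = 0; d < N; ++d)
          (*top)[n * N + d] = weight[indices[n] * N + d];          // row gather
    }

    void embed_backward(const std::vector<int>& indices,
                        const std::vector<float>& top_diff,        // M x N
                        int N, std::vector<float>* weight_diff) {  // K x N
      for (size_t n = 0; n < indices.size(); ++n)
        for (int d = 0; d < N; ++d)                                // scatter-add
          (*weight_diff)[indices[n] * N + d] += top_diff[n * N + d];
    }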
diff --git a/src/caffe/layers/embed_layer.cu b/src/caffe/layers/embed_layer.cu
new file mode 100644
index 00000000000..672fb9c608c
--- /dev/null
+++ b/src/caffe/layers/embed_layer.cu
@@ -0,0 +1,85 @@
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/common_layers.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/gpu_util.cuh"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+__global__ void EmbedForward(const int nthreads, const Dtype* bottom_data,
+    const Dtype* weight, const int M, const int N, const int K,
+    Dtype* top_data) {
+  CUDA_KERNEL_LOOP(top_index, nthreads) {
+    const int n = top_index / N;
+    const int d = top_index % N;
+    const int index = static_cast<int>(bottom_data[n]);
+    const int weight_index = index * N + d;
+    top_data[top_index] = weight[weight_index];
+  }
+}
+
+template <typename Dtype>
+__global__ void EmbedBackward(const int nthreads, const Dtype* bottom_data,
+    const Dtype* top_diff, const int M, const int N, const int K,
+    Dtype* weight_diff);
+
+template <typename Dtype>
+__global__ void EmbedBackward(const int nthreads, const Dtype* bottom_data,
+    const Dtype* top_diff, const int M, const int N, const int K,
+    Dtype* weight_diff) {
+  CUDA_KERNEL_LOOP(top_index, nthreads) {
+    const int n = top_index / N;
+    const int d = top_index % N;
+    const int index = static_cast<int>(bottom_data[n]);
+    const int weight_index = index * N + d;
+    caffe_gpu_atomic_add(top_diff[top_index], weight_diff + weight_index);
+  }
+}
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = top[0]->mutable_gpu_data();
+  const Dtype* weight = this->blobs_[0]->gpu_data();
+  const int count = top[0]->count();
+  EmbedForward<Dtype>  // NOLINT_NEXT_LINE(whitespace/operators)
+      <<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+      count, bottom_data, weight, M_, N_, K_, top_data);
+  if (bias_term_) {
+    caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, Dtype(1),
+        bias_multiplier_.gpu_data(),
+        this->blobs_[1]->gpu_data(), Dtype(1), top_data);
+  }
+}
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+  CHECK(!propagate_down[0]) << "Can't backpropagate to EmbedLayer input.";
+  if (this->param_propagate_down_[0]) {
+    const int top_count = top[0]->count();
+    const int count = this->blobs_[0]->count();
+    const Dtype* top_diff = top[0]->gpu_diff();
+    const Dtype* bottom_data = bottom[0]->gpu_data();
+    Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff();
+    EmbedBackward<Dtype>  // NOLINT_NEXT_LINE(whitespace/operators)
+        <<<CAFFE_GET_BLOCKS(top_count), CAFFE_CUDA_NUM_THREADS>>>(
+        top_count, bottom_data, top_diff, M_, N_, K_, weight_diff);
+  }
+  if (bias_term_ && this->param_propagate_down_[1]) {
+    const Dtype* top_diff = top[0]->gpu_diff();
+    Dtype* bias_diff = this->blobs_[1]->mutable_gpu_diff();
+    caffe_gpu_gemv<Dtype>(CblasTrans, M_, N_, Dtype(1), top_diff,
+        bias_multiplier_.gpu_data(), Dtype(1), bias_diff);
+  }
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(EmbedLayer);
+
+}  // namespace caffe
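EmbedBackward launches one thread per top element, so when the same index appears more than once in the bottom Blob, several threads accumulate into the same weight_diff row concurrently; that is why the kernel goes through caffe_gpu_atomic_add rather than a plain +=. A plain-C++ spot check of the collisions (index values borrowed from the gradient tests below):

    #include <cassert>
    #include <map>

    int main() {
      const int N = 10;
      const int M = 4;
      const int bottom[M] = {4, 2, 2, 3};  // index 2 occurs twice
      std::map<int, int> writers;          // weight_diff offset -> writer count
      for (int n = 0; n < M; ++n)
        for (int d = 0; d < N; ++d)
          ++writers[bottom[n] * N + d];
      // Every entry of weight row 2 is written by two "threads":
      for (int d = 0; d < N; ++d)
        assert(writers[2 * N + d] == 2);
      return 0;
    }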
diff --git a/src/caffe/layers/tile_layer.cpp b/src/caffe/layers/tile_layer.cpp
new file mode 100644
index 00000000000..f55008cc53a
--- /dev/null
+++ b/src/caffe/layers/tile_layer.cpp
@@ -0,0 +1,62 @@
+#include <vector>
+
+#include "caffe/common_layers.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void TileLayer<Dtype>::Reshape(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+  const TileParameter& tile_param = this->layer_param_.tile_param();
+  axis_ = bottom[0]->CanonicalAxisIndex(tile_param.axis());
+  CHECK(tile_param.has_tiles()) << "Number of tiles must be specified";
+  tiles_ = tile_param.tiles();
+  CHECK_GT(tiles_, 0) << "Number of tiles must be positive.";
+  vector<int> top_shape = bottom[0]->shape();
+  top_shape[axis_] = bottom[0]->shape(axis_) * tiles_;
+  top[0]->Reshape(top_shape);
+  outer_dim_ = bottom[0]->count(0, axis_);
+  inner_dim_ = bottom[0]->count(axis_);
+}
+
+template <typename Dtype>
+void TileLayer<Dtype>::Forward_cpu(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* top_data = top[0]->mutable_cpu_data();
+  for (int i = 0; i < outer_dim_; ++i) {
+    for (int t = 0; t < tiles_; ++t) {
+      caffe_copy(inner_dim_, bottom_data, top_data);
+      top_data += inner_dim_;
+    }
+    bottom_data += inner_dim_;
+  }
+}
+
+template <typename Dtype>
+void TileLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+  if (!propagate_down[0]) { return; }
+  const Dtype* top_diff = top[0]->cpu_diff();
+  Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+  for (int i = 0; i < outer_dim_; ++i) {
+    caffe_copy(inner_dim_, top_diff, bottom_diff);
+    top_diff += inner_dim_;
+    for (int t = 1; t < tiles_; ++t) {
+      caffe_axpy(inner_dim_, Dtype(1), top_diff, bottom_diff);
+      top_diff += inner_dim_;
+    }
+    bottom_diff += inner_dim_;
+  }
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(TileLayer);
+#endif
+
+INSTANTIATE_CLASS(TileLayer);
+REGISTER_LAYER_CLASS(Tile);
+
+}  // namespace caffe
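The loops above view the blob as outer_dim_ blocks of inner_dim_ elements (inner_dim_ = count(axis_), so each block starts at the tiled axis) and repeat each block tiles_ times; backward sums the diffs of all copies back into the source block. A minimal plain-C++ sketch of both directions, assuming preallocated output vectors of the right size:

    #include <vector>

    void tile_forward(const std::vector<float>& bottom, int outer_dim,
                      int inner_dim, int tiles, std::vector<float>* top) {
      for (int i = 0; i < outer_dim; ++i)
        for (int t = 0; t < tiles; ++t)
          for (int d = 0; d < inner_dim; ++d)
            (*top)[(i * tiles + t) * inner_dim + d] = bottom[i * inner_dim + d];
    }

    // Each input element feeds `tiles` output elements, so its gradient is
    // the sum of the corresponding top diffs.
    void tile_backward(const std::vector<float>& top_diff, int outer_dim,
                       int inner_dim, int tiles, std::vector<float>* bottom_diff) {
      for (int i = 0; i < outer_dim; ++i)
        for (int d = 0; d < inner_dim; ++d) {
          float sum = 0.0f;
          for (int t = 0; t < tiles; ++t)
            sum += top_diff[(i * tiles + t) * inner_dim + d];
          (*bottom_diff)[i * inner_dim + d] = sum;
        }
    }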
" + << "To learn this layer's parameters from scratch rather than " + << "copying from a saved net, rename the layer."; + } const bool kReshape = false; target_blobs[j]->FromProto(source_layer.blobs(j), kReshape); } diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index d4c97d2bd06..aa299f8660b 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -301,7 +301,7 @@ message ParamSpec { // NOTE // Update the next available ID when you add a new LayerParameter field. // -// LayerParameter next available layer-specific ID: 137 (last added: reduction_param) +// LayerParameter next available layer-specific ID: 139 (last added: tile_param) message LayerParameter { optional string name = 1; // the layer name optional string type = 2; // the layer type @@ -357,6 +357,7 @@ message LayerParameter { optional DropoutParameter dropout_param = 108; optional DummyDataParameter dummy_data_param = 109; optional EltwiseParameter eltwise_param = 110; + optional EmbedParameter embed_param = 137; optional ExpParameter exp_param = 111; optional FlattenParameter flatten_param = 135; optional HDF5DataParameter hdf5_data_param = 112; @@ -382,6 +383,7 @@ message LayerParameter { optional SliceParameter slice_param = 126; optional TanHParameter tanh_param = 127; optional ThresholdParameter threshold_param = 128; + optional TileParameter tile_param = 138; optional WindowDataParameter window_data_param = 129; } @@ -562,6 +564,21 @@ message EltwiseParameter { optional bool stable_prod_grad = 3 [default = true]; } +// Message that stores parameters used by EmbedLayer +message EmbedParameter { + optional uint32 num_output = 1; // The number of outputs for the layer + // The input is given as integers to be interpreted as one-hot + // vector indices with dimension num_input. Hence num_input should be + // 1 greater than the maximum possible input value. + optional uint32 input_dim = 2; + + optional bool bias_term = 3 [default = true]; // Whether to use a bias term + optional FillerParameter weight_filler = 4; // The filler for the weight + optional FillerParameter bias_filler = 5; // The filler for the bias + +} + +// Message that stores parameters used by ExpLayer message ExpParameter { // ExpLayer computes outputs y = base ^ (shift + scale * x), for base > 0. // Or if base is set to the default (-1), base is set to e, @@ -903,6 +920,16 @@ message TanHParameter { optional Engine engine = 1 [default = DEFAULT]; } +// Message that stores parameters used by TileLayer +message TileParameter { + // The index of the axis to tile. + optional int32 axis = 1 [default = 1]; + + // The number of copies (tiles) of the blob to output. 
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index 1a0b0404bb6..11368f2f78b 100644
--- a/src/caffe/net.cpp
+++ b/src/caffe/net.cpp
@@ -429,7 +429,8 @@ void Net<Dtype>::AppendTop(const NetParameter& param, const int layer_id,
              blob_name_to_idx->find(blob_name) != blob_name_to_idx->end()) {
     // If we are not doing in-place computation but have duplicated blobs,
     // raise an error.
-    LOG(FATAL) << "Duplicate blobs produced by multiple sources.";
+    LOG(FATAL) << "Top blob '" << blob_name
+               << "' produced by multiple sources.";
   } else {
     // Normal output.
     if (Caffe::root_solver()) {
@@ -473,8 +474,8 @@ int Net<Dtype>::AppendBottom(const NetParameter& param, const int layer_id,
   const LayerParameter& layer_param = param.layer(layer_id);
   const string& blob_name = layer_param.bottom(bottom_id);
   if (available_blobs->find(blob_name) == available_blobs->end()) {
-    LOG(FATAL) << "Unknown blob input " << blob_name
-               << " (at index " << bottom_id << ") to layer " << layer_id;
+    LOG(FATAL) << "Unknown bottom blob '" << blob_name << "' (layer '"
+               << layer_param.name() << "', bottom index " << bottom_id << ")";
   }
   const int blob_id = (*blob_name_to_idx)[blob_name];
   if (Caffe::root_solver()) {
@@ -550,10 +551,19 @@ void Net<Dtype>::AppendParam(const NetParameter& param, const int layer_id,
             ParamSpec_DimCheckMode_PERMISSIVE)) {
       // Permissive dimension checking -- only check counts are the same.
       CHECK_EQ(this_blob->count(), owner_blob->count())
-          << "Shared parameter blobs must have the same count.";
+          << "Cannot share param '" << param_name << "' owned by layer '"
+          << layer_names_[owner_layer_id] << "' with layer '"
+          << layer_names_[layer_id] << "'; count mismatch. Owner layer param "
+          << "shape is " << owner_blob->shape_string() << "; sharing layer "
+          << "shape is " << this_blob->shape_string();
     } else {
       // Strict dimension checking -- all dims must be the same.
-      CHECK(this_blob->shape() == owner_blob->shape());
+      CHECK(this_blob->shape() == owner_blob->shape())
+          << "Cannot share param '" << param_name << "' owned by layer '"
+          << layer_names_[owner_layer_id] << "' with layer '"
+          << layer_names_[layer_id] << "'; shape mismatch. Owner layer param "
+          << "shape is " << owner_blob->shape_string() << "; sharing layer "
+          << "expects shape " << this_blob->shape_string();
     }
     const int learnable_param_id = learnable_param_ids_[owner_net_param_id];
     learnable_param_ids_.push_back(learnable_param_id);
@@ -780,7 +790,11 @@ void Net<Dtype>::ShareTrainedLayersWith(const Net* other) {
         << "Incompatible number of blobs for layer " << source_layer_name;
     for (int j = 0; j < target_blobs.size(); ++j) {
       Blob<Dtype>* source_blob = source_layer->blobs()[j].get();
-      CHECK(target_blobs[j]->shape() == source_blob->shape());
+      CHECK(target_blobs[j]->shape() == source_blob->shape())
+          << "Cannot share param " << j << " weights from layer '"
+          << source_layer_name << "'; shape mismatch. Source param shape is "
+          << source_blob->shape_string() << "; target param shape is "
+          << target_blobs[j]->shape_string();
       target_blobs[j]->ShareData(*source_blob);
     }
   }
@@ -844,6 +858,17 @@ void Net<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) {
     CHECK_EQ(target_blobs.size(), source_layer.blobs_size())
         << "Incompatible number of blobs for layer " << source_layer_name;
     for (int j = 0; j < target_blobs.size(); ++j) {
+      if (!target_blobs[j]->ShapeEquals(source_layer.blobs(j))) {
+        Blob<Dtype> source_blob;
+        const bool kReshape = true;
+        source_blob.FromProto(source_layer.blobs(j), kReshape);
+        LOG(FATAL) << "Cannot copy param " << j << " weights from layer '"
+            << source_layer_name << "'; shape mismatch. Source param shape is "
+            << source_blob.shape_string() << "; target param shape is "
+            << target_blobs[j]->shape_string() << ". "
+            << "To learn this layer's parameters from scratch rather than "
+            << "copying from a saved net, rename the layer.";
+      }
       const bool kReshape = false;
       target_blobs[j]->FromProto(source_layer.blobs(j), kReshape);
     }
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index d4c97d2bd06..aa299f8660b 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -301,7 +301,7 @@ message ParamSpec {
 // NOTE
 // Update the next available ID when you add a new LayerParameter field.
 //
-// LayerParameter next available layer-specific ID: 137 (last added: reduction_param)
+// LayerParameter next available layer-specific ID: 139 (last added: tile_param)
 message LayerParameter {
   optional string name = 1; // the layer name
   optional string type = 2; // the layer type
@@ -357,6 +357,7 @@ message LayerParameter {
   optional DropoutParameter dropout_param = 108;
   optional DummyDataParameter dummy_data_param = 109;
   optional EltwiseParameter eltwise_param = 110;
+  optional EmbedParameter embed_param = 137;
   optional ExpParameter exp_param = 111;
   optional FlattenParameter flatten_param = 135;
   optional HDF5DataParameter hdf5_data_param = 112;
@@ -382,6 +383,7 @@ message LayerParameter {
   optional SliceParameter slice_param = 126;
   optional TanHParameter tanh_param = 127;
   optional ThresholdParameter threshold_param = 128;
+  optional TileParameter tile_param = 138;
   optional WindowDataParameter window_data_param = 129;
 }
 
@@ -562,6 +564,21 @@ message EltwiseParameter {
   optional bool stable_prod_grad = 3 [default = true];
 }
 
+// Message that stores parameters used by EmbedLayer
+message EmbedParameter {
+  optional uint32 num_output = 1; // The number of outputs for the layer
+  // The input is given as integers to be interpreted as one-hot
+  // vector indices with dimension input_dim. Hence input_dim should be
+  // 1 greater than the maximum possible input value.
+  optional uint32 input_dim = 2;
+
+  optional bool bias_term = 3 [default = true]; // Whether to use a bias term
+  optional FillerParameter weight_filler = 4; // The filler for the weight
+  optional FillerParameter bias_filler = 5; // The filler for the bias
+
+}
+
+// Message that stores parameters used by ExpLayer
 message ExpParameter {
   // ExpLayer computes outputs y = base ^ (shift + scale * x), for base > 0.
   // Or if base is set to the default (-1), base is set to e,
@@ -903,6 +920,16 @@ message TanHParameter {
   optional Engine engine = 1 [default = DEFAULT];
 }
 
+// Message that stores parameters used by TileLayer
+message TileParameter {
+  // The index of the axis to tile.
+  optional int32 axis = 1 [default = 1];
+
+  // The number of copies (tiles) of the blob to output.
+  optional int32 tiles = 2;
+}
+
+// Message that stores parameters used by ThresholdLayer
 message ThresholdParameter {
   optional float threshold = 1 [default = 0]; // Strictly positive values
 }
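For reference, a hypothetical C++ fragment (mirroring what the unit tests below do) configuring the two new messages through the generated protobuf setters:

    LayerParameter embed_layer_param;
    EmbedParameter* embed_param = embed_layer_param.mutable_embed_param();
    embed_param->set_num_output(10);  // N_: embedding width
    embed_param->set_input_dim(5);    // K_: must exceed the largest input index
    embed_param->set_bias_term(false);

    LayerParameter tile_layer_param;
    TileParameter* tile_param = tile_layer_param.mutable_tile_param();
    tile_param->set_axis(1);   // tile along the channel axis
    tile_param->set_tiles(3);  // emit three copies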
diff --git a/src/caffe/test/test_concat_layer.cpp b/src/caffe/test/test_concat_layer.cpp
index 662a50fa23b..088e0a41685 100644
--- a/src/caffe/test/test_concat_layer.cpp
+++ b/src/caffe/test/test_concat_layer.cpp
@@ -173,4 +173,13 @@ TYPED_TEST(ConcatLayerTest, TestGradientChannels) {
       this->blob_top_vec_);
 }
 
+TYPED_TEST(ConcatLayerTest, TestGradientChannelsBottomOneOnly) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  ConcatLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-2);
+  checker.CheckGradient(&layer, this->blob_bottom_vec_0_,
+      this->blob_top_vec_, 1);
+}
+
 }  // namespace caffe
diff --git a/src/caffe/test/test_embed_layer.cpp b/src/caffe/test/test_embed_layer.cpp
new file mode 100644
index 00000000000..7a4fb9800f2
--- /dev/null
+++ b/src/caffe/test/test_embed_layer.cpp
@@ -0,0 +1,183 @@
+#include <cstring>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/vision_layers.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+#include "caffe/test/test_gradient_check_util.hpp"
+
+namespace caffe {
+
+#ifndef CPU_ONLY
+extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
+#endif
+
+template <typename TypeParam>
+class EmbedLayerTest : public MultiDeviceTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+ protected:
+  EmbedLayerTest()
+      : blob_bottom_(new Blob<Dtype>(4, 1, 1, 1)),
+        blob_top_(new Blob<Dtype>()) {
+    // fill the values
+    FillerParameter filler_param;
+    UniformFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    blob_bottom_vec_.push_back(blob_bottom_);
+    blob_top_vec_.push_back(blob_top_);
+  }
+  virtual ~EmbedLayerTest() { delete blob_bottom_; delete blob_top_; }
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_top_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+TYPED_TEST_CASE(EmbedLayerTest, TestDtypesAndDevices);
+
+TYPED_TEST(EmbedLayerTest, TestSetUp) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  EmbedParameter* embed_param = layer_param.mutable_embed_param();
+  embed_param->set_num_output(10);
+  embed_param->set_input_dim(5);
+  shared_ptr<EmbedLayer<Dtype> > layer(new EmbedLayer<Dtype>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  ASSERT_EQ(this->blob_top_->num_axes(), 5);
+  EXPECT_EQ(this->blob_top_->shape(0), 4);
+  EXPECT_EQ(this->blob_top_->shape(1), 1);
+  EXPECT_EQ(this->blob_top_->shape(2), 1);
+  EXPECT_EQ(this->blob_top_->shape(3), 1);
+  EXPECT_EQ(this->blob_top_->shape(4), 10);
+}
+
+TYPED_TEST(EmbedLayerTest, TestForward) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  EmbedParameter* embed_param = layer_param.mutable_embed_param();
+  const int kNumOutput = 10;
+  const int kInputDim = 5;
+  embed_param->set_num_output(kNumOutput);
+  embed_param->set_input_dim(kInputDim);
+  embed_param->mutable_weight_filler()->set_type("uniform");
+  embed_param->mutable_weight_filler()->set_min(-10);
+  embed_param->mutable_weight_filler()->set_max(10);
+  embed_param->set_bias_term(false);
+  shared_ptr<EmbedLayer<Dtype> > layer(new EmbedLayer<Dtype>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  ASSERT_EQ(1, layer->blobs().size());
+  vector<int> weight_shape(2);
+  weight_shape[0] = kInputDim;
+  weight_shape[1] = kNumOutput;
+  ASSERT_TRUE(weight_shape == layer->blobs()[0]->shape());
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    this->blob_bottom_->mutable_cpu_data()[i] = caffe_rng_rand() % kInputDim;
+  }
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  vector<int> weight_offset(2, 0);
+  vector<int> top_offset(5, 0);
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    weight_offset[0] = static_cast<int>(this->blob_bottom_->cpu_data()[i]);
+    weight_offset[1] = 0;
+    top_offset[0] = i;
+    top_offset[4] = 0;
+    for (int j = 0; j < kNumOutput; ++j) {
+      EXPECT_EQ(layer->blobs()[0]->data_at(weight_offset),
+                this->blob_top_->data_at(top_offset));
+      ++top_offset[4];
+      ++weight_offset[1];
+    }
+  }
+}
+
+TYPED_TEST(EmbedLayerTest, TestForwardWithBias) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  EmbedParameter* embed_param = layer_param.mutable_embed_param();
+  const int kNumOutput = 10;
+  const int kInputDim = 5;
+  embed_param->set_num_output(kNumOutput);
+  embed_param->set_input_dim(kInputDim);
+  embed_param->mutable_weight_filler()->set_type("uniform");
+  embed_param->mutable_weight_filler()->set_min(-10);
+  embed_param->mutable_weight_filler()->set_max(10);
+  embed_param->mutable_bias_filler()->CopyFrom(embed_param->weight_filler());
+  embed_param->set_bias_term(true);
+  shared_ptr<EmbedLayer<Dtype> > layer(new EmbedLayer<Dtype>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  ASSERT_EQ(2, layer->blobs().size());
+  vector<int> weight_shape(2);
+  weight_shape[0] = kInputDim;
+  weight_shape[1] = kNumOutput;
+  ASSERT_TRUE(weight_shape == layer->blobs()[0]->shape());
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    this->blob_bottom_->mutable_cpu_data()[i] = caffe_rng_rand() % kInputDim;
+  }
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  vector<int> bias_offset(1, 0);
+  vector<int> weight_offset(2, 0);
+  vector<int> top_offset(5, 0);
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    weight_offset[0] = static_cast<int>(this->blob_bottom_->cpu_data()[i]);
+    weight_offset[1] = 0;
+    top_offset[0] = i;
+    top_offset[4] = 0;
+    bias_offset[0] = 0;
+    for (int j = 0; j < kNumOutput; ++j) {
+      EXPECT_EQ(layer->blobs()[0]->data_at(weight_offset) +
+                layer->blobs()[1]->data_at(bias_offset),
+                this->blob_top_->data_at(top_offset));
+      ++top_offset[4];
+      ++weight_offset[1];
+      ++bias_offset[0];
+    }
+  }
+}
+
+TYPED_TEST(EmbedLayerTest, TestGradient) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  EmbedParameter* embed_param = layer_param.mutable_embed_param();
+  embed_param->set_num_output(10);
+  embed_param->set_input_dim(5);
+  embed_param->set_bias_term(false);
+  embed_param->mutable_weight_filler()->set_type("uniform");
+  embed_param->mutable_weight_filler()->set_min(-10);
+  embed_param->mutable_weight_filler()->set_max(10);
+  EmbedLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-3);
+  this->blob_bottom_->mutable_cpu_data()[0] = 4;
+  this->blob_bottom_->mutable_cpu_data()[1] = 2;
+  this->blob_bottom_->mutable_cpu_data()[2] = 2;
+  this->blob_bottom_->mutable_cpu_data()[3] = 3;
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_, -2);
+}
+
+TYPED_TEST(EmbedLayerTest, TestGradientWithBias) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  EmbedParameter* embed_param = layer_param.mutable_embed_param();
+  embed_param->set_num_output(10);
+  embed_param->set_input_dim(5);
+  embed_param->set_bias_term(true);
+  embed_param->mutable_weight_filler()->set_type("uniform");
+  embed_param->mutable_weight_filler()->set_min(-10);
+  embed_param->mutable_weight_filler()->set_max(10);
+  embed_param->mutable_bias_filler()->CopyFrom(embed_param->weight_filler());
+  EmbedLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-3);
+  this->blob_bottom_->mutable_cpu_data()[0] = 4;
+  this->blob_bottom_->mutable_cpu_data()[1] = 2;
+  this->blob_bottom_->mutable_cpu_data()[2] = 2;
+  this->blob_bottom_->mutable_cpu_data()[3] = 3;
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_, -2);
+}
+
+}  // namespace caffe
diff --git a/src/caffe/test/test_tile_layer.cpp b/src/caffe/test/test_tile_layer.cpp
new file mode 100644
index 00000000000..540aac3c2d3
--- /dev/null
+++ b/src/caffe/test/test_tile_layer.cpp
@@ -0,0 +1,162 @@
+#include <cstring>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/vision_layers.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+#include "caffe/test/test_gradient_check_util.hpp"
+
+namespace caffe {
+
+template <typename TypeParam>
+class TileLayerTest : public MultiDeviceTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+
+ protected:
+  TileLayerTest()
+      : blob_bottom_(new Blob<Dtype>(2, 3, 4, 5)),
+        blob_top_(new Blob<Dtype>()) {}
+  virtual void SetUp() {
+    blob_bottom_vec_.push_back(blob_bottom_);
+    blob_top_vec_.push_back(blob_top_);
+    FillerParameter filler_param;
+    filler_param.set_mean(0.0);
+    filler_param.set_std(1.0);
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(blob_bottom_);
+  }
+
+  virtual ~TileLayerTest() {
+    delete blob_bottom_;
+    delete blob_top_;
+  }
+
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_top_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+TYPED_TEST_CASE(TileLayerTest, TestDtypesAndDevices);
+
+TYPED_TEST(TileLayerTest, TestTrivialSetup) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  const int kNumTiles = 1;
+  layer_param.mutable_tile_param()->set_tiles(kNumTiles);
+  for (int i = 0; i < this->blob_bottom_->num_axes(); ++i) {
+    layer_param.mutable_tile_param()->set_axis(i);
+    TileLayer<Dtype> layer(layer_param);
+    layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+    ASSERT_EQ(this->blob_top_->num_axes(), this->blob_bottom_->num_axes());
+    for (int j = 0; j < this->blob_bottom_->num_axes(); ++j) {
+      EXPECT_EQ(this->blob_top_->shape(j), this->blob_bottom_->shape(j));
+    }
+  }
+}
+
+TYPED_TEST(TileLayerTest, TestSetup) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  const int kNumTiles = 3;
+  layer_param.mutable_tile_param()->set_tiles(kNumTiles);
+  for (int i = 0; i < this->blob_bottom_->num_axes(); ++i) {
+    layer_param.mutable_tile_param()->set_axis(i);
+    TileLayer<Dtype> layer(layer_param);
+    layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+    ASSERT_EQ(this->blob_top_->num_axes(), this->blob_bottom_->num_axes());
+    for (int j = 0; j < this->blob_bottom_->num_axes(); ++j) {
+      const int top_dim =
+          ((i == j) ? kNumTiles : 1) * this->blob_bottom_->shape(j);
+      EXPECT_EQ(top_dim, this->blob_top_->shape(j));
+    }
+  }
+}
+
+TYPED_TEST(TileLayerTest, TestForwardNum) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  const int kTileAxis = 0;
+  const int kNumTiles = 3;
+  layer_param.mutable_tile_param()->set_axis(kTileAxis);
+  layer_param.mutable_tile_param()->set_tiles(kNumTiles);
+  TileLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  for (int n = 0; n < this->blob_top_->num(); ++n) {
+    for (int c = 0; c < this->blob_top_->channels(); ++c) {
+      for (int h = 0; h < this->blob_top_->height(); ++h) {
+        for (int w = 0; w < this->blob_top_->width(); ++w) {
+          const int bottom_n = n % this->blob_bottom_->num();
+          EXPECT_EQ(this->blob_bottom_->data_at(bottom_n, c, h, w),
+                    this->blob_top_->data_at(n, c, h, w));
+        }
+      }
+    }
+  }
+}
+
+TYPED_TEST(TileLayerTest, TestForwardChannels) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  const int kNumTiles = 3;
+  layer_param.mutable_tile_param()->set_tiles(kNumTiles);
+  TileLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  for (int n = 0; n < this->blob_top_->num(); ++n) {
+    for (int c = 0; c < this->blob_top_->channels(); ++c) {
+      for (int h = 0; h < this->blob_top_->height(); ++h) {
+        for (int w = 0; w < this->blob_top_->width(); ++w) {
+          const int bottom_c = c % this->blob_bottom_->channels();
+          EXPECT_EQ(this->blob_bottom_->data_at(n, bottom_c, h, w),
+                    this->blob_top_->data_at(n, c, h, w));
+        }
+      }
+    }
+  }
+}
+
+TYPED_TEST(TileLayerTest, TestTrivialGradient) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  const int kNumTiles = 1;
+  layer_param.mutable_tile_param()->set_tiles(kNumTiles);
+  TileLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-2);
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_);
+}
+
+TYPED_TEST(TileLayerTest, TestGradientNum) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  const int kTileAxis = 0;
+  const int kNumTiles = 3;
+  layer_param.mutable_tile_param()->set_axis(kTileAxis);
+  layer_param.mutable_tile_param()->set_tiles(kNumTiles);
+  TileLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-2);
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_);
+}
+
+TYPED_TEST(TileLayerTest, TestGradientChannels) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  const int kTileAxis = 1;
+  const int kNumTiles = 3;
+  layer_param.mutable_tile_param()->set_axis(kTileAxis);
+  layer_param.mutable_tile_param()->set_tiles(kNumTiles);
+  TileLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-2);
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_);
+}
+
+}  // namespace caffe
'" << blob_name << "' (layer '" + << layer_param.name() << "', bottom index " << j << ")"; } const pair& bottom_idx = make_pair(i, j); const pair& top_idx = blob_name_to_last_top_idx[blob_name];