
Commit

Merge pull request BVLC#6 from cbfinn/rnn_split_layer
Add option to set # hidden units in RNN layer.
justinjfu committed May 22, 2015
2 parents 60fe9c1 + cadea68 commit 8d27f32
Showing 4 changed files with 58 additions and 3 deletions.
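
For context, a minimal sketch of how the new option might be set in a net prototxt. The layer and blob names and the filler settings are hypothetical, and the layer is assumed to be registered under the type string "RNN"; only recurrent_param.num_output and rnn_param.num_hidden come from this change.

layer {
  name: "rnn1"
  type: "RNN"
  bottom: "data"
  bottom: "cont"   # sequence continuation indicators
  top: "rnn1"
  recurrent_param {
    num_output: 11
    weight_filler { type: "uniform" min: -0.08 max: 0.08 }
    bias_filler { type: "constant" value: 0 }
  }
  rnn_param {
    num_hidden: 7  # new: size of the hidden state h_t; 0 (the default) falls back to num_output
  }
}

With num_hidden left unset (or 0), behavior is unchanged and the hidden state keeps the num_output size.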
2 changes: 2 additions & 0 deletions include/caffe/sequence_layers.hpp
@@ -36,6 +36,8 @@ class RecurrentLayer : public Layer<Dtype> {
virtual inline int MinBottomBlobs() const { return 2; }
virtual inline int MaxBottomBlobs() const { return 3; }
virtual inline int ExactNumTopBlobs() const { return 1; }
virtual inline shared_ptr<Net<Dtype> > UnrolledNet() const { return unrolled_net_; }


virtual inline bool AllowForceBackward(const int bottom_index) const {
// Can't propagate to sequence continuation indicators.
17 changes: 14 additions & 3 deletions src/caffe/layers/rnn_layer.cpp
@@ -31,6 +31,13 @@ void RNNLayer<Dtype>::OutputBlobNames(vector<string>* names) const {
template <typename Dtype>
void RNNLayer<Dtype>::FillUnrolledNet(NetParameter* net_param) const {
const int num_output = this->layer_param_.recurrent_param().num_output();

// Set the number of hidden units; 0 (the default) means use num_output.
int num_hidden = this->layer_param_.rnn_param().num_hidden();
if (num_hidden == 0) {
  num_hidden = num_output;
}

CHECK_GT(num_output, 0) << "num_output must be positive";
const FillerParameter& weight_filler =
this->layer_param_.recurrent_param().weight_filler();
@@ -41,7 +48,7 @@ void RNNLayer<Dtype>::FillUnrolledNet(NetParameter* net_param) const {
// use to save redundant code.
LayerParameter hidden_param;
hidden_param.set_type("InnerProduct");
-hidden_param.mutable_inner_product_param()->set_num_output(num_output);
+hidden_param.mutable_inner_product_param()->set_num_output(num_hidden);
hidden_param.mutable_inner_product_param()->set_bias_term(false);
hidden_param.mutable_inner_product_param()->set_axis(2);
hidden_param.mutable_inner_product_param()->
@@ -52,6 +59,10 @@ void RNNLayer<Dtype>::FillUnrolledNet(NetParameter* net_param) const {
biased_hidden_param.mutable_inner_product_param()->
mutable_bias_filler()->CopyFrom(bias_filler);

// Parameters of the output layer inner product: o_t = g(W_ho * h_t + b_o).
LayerParameter biased_output_param(biased_hidden_param);
biased_output_param.mutable_inner_product_param()->set_num_output(num_output);

LayerParameter sum_param;
sum_param.set_type("Eltwise");
sum_param.mutable_eltwise_param()->set_operation(
@@ -70,7 +81,7 @@ void RNNLayer<Dtype>::FillUnrolledNet(NetParameter* net_param) const {
BlobShape input_shape;
input_shape.add_dim(1); // h_0 is a single timestep
input_shape.add_dim(this->N_);
-input_shape.add_dim(num_output);
+input_shape.add_dim(num_hidden);
net_param->add_input("h_0");
net_param->add_input_shape()->CopyFrom(input_shape);

@@ -189,7 +200,7 @@ void RNNLayer<Dtype>::FillUnrolledNet(NetParameter* net_param) const {
// W_ho_h_t := W_ho * h_t + b_o
{
LayerParameter* w_param = net_param->add_layer();
-w_param->CopyFrom(biased_hidden_param);
+w_param->CopyFrom(biased_output_param);
w_param->set_name("W_ho_h_" + ts);
w_param->add_param()->set_name("W_ho");
w_param->add_param()->set_name("b_o");
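
Reading this diff together with the new test: with H = num_hidden, O = num_output, and D the input feature dimension, the unrolled net's learnable blobs become W_xh (H x D) with bias b_h (H), W_hh (H x H, no bias), and W_ho (O x H) with bias b_o (O), i.e. H*D + H*H + H + H*O + O parameters in total; before this change H was implicitly tied to O.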
2 changes: 2 additions & 0 deletions src/caffe/proto/caffe.proto
@@ -728,6 +728,8 @@ message RNNParameter {
optional string output_nonlinearity = 1 [default = "TanH"];
// Nonlinearity on recurrent state (h_t)
optional string recurrent_nonlinearity = 2 [default = "TanH"];
// Number of hidden units. 0 (default) means use num_output from RecurrentParameter
optional uint32 num_hidden = 3 [default = 0];
}

// Message that stores parameters used by ReLULayer
40 changes: 40 additions & 0 deletions src/caffe/test/test_rnn_layer.cpp
@@ -193,4 +193,44 @@ TYPED_TEST(RNNLayerTest, TestGradientNonZeroFlushBufferSize2) {
this->blob_top_vec_, 0);
}

TYPED_TEST(RNNLayerTest, TestNumHiddenShape) {
// Check the top blob shape and the total parameter count when num_hidden is set.

typedef typename TypeParam::Dtype Dtype;
const int nout = 11;
const int nhid = 7;
const int nin = 6;

LayerParameter new_params(this->layer_param_);
new_params.mutable_rnn_param()->set_num_hidden(nhid);
new_params.mutable_recurrent_param()->set_num_output(nout);
RNNLayer<Dtype> layer(new_params);

layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
vector<int> expected_top_shape = this->blob_bottom_.shape();
expected_top_shape.resize(3);
expected_top_shape[2] = nout;
EXPECT_TRUE(this->blob_top_.shape() == expected_top_shape);

vector<shared_ptr<Blob<Dtype> > > params = layer.UnrolledNet()->params();
int total_count = 0;
for (typename vector<shared_ptr<Blob<Dtype> > >::iterator it = params.begin();
     it != params.end(); ++it) {
  total_count += (*it)->count();
}
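// W_xh (nin*nhid) + W_hh (nhid*nhid) + b_h (nhid), plus W_ho (nhid*nout) + b_o (nout):
// (42 + 49 + 7) + (77 + 11) = 186 for nin = 6, nhid = 7, nout = 11.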
const int expected_count = (nin*nhid+nhid*nhid+nhid)+(nhid*nout+nout);
EXPECT_EQ(total_count, expected_count);
}

TYPED_TEST(RNNLayerTest, TestGradientWithNumHidden) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter new_params(this->layer_param_);
new_params.mutable_rnn_param()->set_num_hidden(30);
RNNLayer<Dtype> layer(new_params);

GradientChecker<Dtype> checker(1e-2, 1e-3);
checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
this->blob_top_vec_, 0);
}

} // namespace caffe
