net.cpp now allows zero-sized batches
mtamburrano committed May 18, 2015
1 parent 352aef4 commit 40f4f6c
Showing 2 changed files with 50 additions and 4 deletions.
13 changes: 13 additions & 0 deletions include/caffe/net.hpp
@@ -57,6 +57,13 @@ class Net {
    */
   string Forward(const string& input_blob_protos, Dtype* loss = NULL);
 
+  /**
+   * If any bottom blob of layer i has num == 0, then its forward pass is
+   * not allowed; a top blob of the layer is reshaped to num = 0 so that
+   * all subsequent layers skip their forward passes as well.
+   */
+  bool ForwardIsAllowed(int i);
+
   /**
    * The network backward should take no input and output, since it solely
    * computes the gradient w.r.t the parameters, and the data has already been
@@ -67,6 +74,12 @@
   void BackwardFrom(int start);
   void BackwardTo(int end);
 
+  /**
+   * If a top blob of layer i has num == 0, the forward pass on that layer
+   * was denied, so there is no need to backpropagate through it.
+   */
+  bool BackwardIsAllowed(int i);
+
   /**
    * @brief Reshape all layers from bottom to top.
    *
41 changes: 37 additions & 4 deletions src/caffe/net.cpp
@@ -505,9 +505,12 @@ Dtype Net<Dtype>::ForwardFromTo(int start, int end) {
   }
   for (int i = start; i <= end; ++i) {
     // LOG(ERROR) << "Forwarding " << layer_names_[i];
-    Dtype layer_loss = layers_[i]->Forward(bottom_vecs_[i], top_vecs_[i]);
-    loss += layer_loss;
-    if (debug_info_) { ForwardDebugInfo(i); }
+    if (ForwardIsAllowed(i)) {
+      layers_[i]->Reshape(bottom_vecs_[i], top_vecs_[i]);
+      Dtype layer_loss = layers_[i]->Forward(bottom_vecs_[i], top_vecs_[i]);
+      loss += layer_loss;
+      if (debug_info_) { ForwardDebugInfo(i); }
+    }
   }
   return loss;
 }
@@ -568,7 +571,7 @@ void Net<Dtype>::BackwardFromTo(int start, int end) {
   CHECK_GE(end, 0);
   CHECK_LT(start, layers_.size());
   for (int i = start; i >= end; --i) {
-    if (layer_need_backward_[i]) {
+    if (layer_need_backward_[i] && BackwardIsAllowed(i)) {
       layers_[i]->Backward(
           top_vecs_[i], bottom_need_backward_[i], bottom_vecs_[i]);
       if (debug_info_) { BackwardDebugInfo(i); }
@@ -718,6 +721,36 @@ void Net<Dtype>::Reshape() {
   }
 }
 
+template <typename Dtype>
+bool Net<Dtype>::ForwardIsAllowed(int i) {
+  bool forward_allowed = true;
+  for (int b = 0; forward_allowed && b < bottom_vecs_[i].size(); ++b) {
+    if (bottom_vecs_[i][b]->num() == 0) {
+      // If a bottom has num == 0, deny the forward and reshape a top
+      // to num = 0 to deny the forward of subsequent layers.
+      if (top_vecs_[i].size() > 0) {
+        top_vecs_[i][0]->Reshape(0,
+            top_vecs_[i][0]->channels(),
+            top_vecs_[i][0]->height(),
+            top_vecs_[i][0]->width());
+      }
+      forward_allowed = false;
+    }
+  }
+  return forward_allowed;
+}
+
+template <typename Dtype>
+bool Net<Dtype>::BackwardIsAllowed(int i) {
+  bool backward_allowed = true;
+  for (int t = 0; backward_allowed && t < top_vecs_[i].size(); ++t) {
+    if (top_vecs_[i][t]->num() == 0) {
+      backward_allowed = false;
+    }
+  }
+  return backward_allowed;
+}
+
 template <typename Dtype>
 void Net<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) {
   int num_source_layers = param.layer_size();
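For orientation, here is a minimal, self-contained sketch of the cascade this commit implements: when a layer's bottom blob has num == 0, the layer is skipped and its top is zeroed, so the empty batch propagates through every downstream layer. The SimpleBlob type and ForwardChain function below are hypothetical stand-ins for Caffe's Blob and Net<Dtype>::ForwardFromTo, not code from this commit.

#include <cstdio>
#include <vector>

// Hypothetical stand-in for caffe::Blob; only the num dimension matters here.
struct SimpleBlob {
  int num;
};

// Mimics the guarded loop in Net<Dtype>::ForwardFromTo: a zero-sized bottom
// zeroes the top and skips the layer, so the empty batch cascades onward.
void ForwardChain(std::vector<SimpleBlob>& blobs) {
  // blobs[i] is the bottom of layer i; blobs[i + 1] is its top.
  for (size_t i = 0; i + 1 < blobs.size(); ++i) {
    if (blobs[i].num == 0) {
      blobs[i + 1].num = 0;  // deny the forward of subsequent layers too
      std::printf("layer %zu skipped (empty batch)\n", i);
      continue;
    }
    blobs[i + 1].num = blobs[i].num;  // a real layer would Reshape + Forward
    std::printf("layer %zu forwarded %d items\n", i, blobs[i].num);
  }
}

int main() {
  std::vector<SimpleBlob> chain = {{0}, {5}, {5}};  // input batch is empty
  ForwardChain(chain);  // prints "skipped" for both layers
  return 0;
}

In the commit itself, ForwardIsAllowed performs this check over bottom_vecs_[i], and the Reshape call added before Forward presumably lets tops regain a nonzero num once a non-empty batch arrives again.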
