Modular model definitions
cypof committed Oct 17, 2014
1 parent: 1ed11d5, commit: 2c9abcb

Showing 11 changed files with 261 additions and 62 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -79,6 +79,10 @@ docs/dev
 
 # Eclipse Project settings
 *.*project
+.settings
+
+# CMake generated files
+*.gen.cmake
 
 # OSX files
 .DS_Store

6 changes: 3 additions & 3 deletions Makefile
@@ -4,8 +4,8 @@ CONFIG_FILE := Makefile.config
 include $(CONFIG_FILE)
 
 BUILD_DIR_LINK := $(BUILD_DIR)
-RELEASE_BUILD_DIR := .$(BUILD_DIR)_release
-DEBUG_BUILD_DIR := .$(BUILD_DIR)_debug
+RELEASE_BUILD_DIR ?= .$(BUILD_DIR)_release
+DEBUG_BUILD_DIR ?= .$(BUILD_DIR)_debug
 
 DEBUG ?= 0
 ifeq ($(DEBUG), 1)
@@ -269,7 +269,7 @@ endif
 
 # Debugging
 ifeq ($(DEBUG), 1)
-	COMMON_FLAGS += -DDEBUG -g -O0
+	COMMON_FLAGS += -DDEBUG -g -O0 -DBOOST_NOINLINE='__attribute__ ((noinline))'
 	NVCCFLAGS += -G
 else
 	COMMON_FLAGS += -DNDEBUG -O2

30 changes: 30 additions & 0 deletions examples/mnist/lenet_conv_pool.prototxt
@@ -0,0 +1,30 @@
+layers {
+  name: "conv"
+  type: CONVOLUTION
+  bottom: "${bottom}"
+  top: "conv"
+  blobs_lr: 1
+  blobs_lr: 2
+  convolution_param {
+    num_output: ${num_output}
+    kernel_size: 5
+    stride: 1
+    weight_filler {
+      type: "xavier"
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+}
+layers {
+  name: "pool"
+  type: POOLING
+  bottom: "conv"
+  top: "pool"
+  pooling_param {
+    pool: MAX
+    kernel_size: 2
+    stride: 2
+  }
+}

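Note on the template above: ${bottom} and ${num_output} are plain-text variables. Before the imported file is parsed, each ${name} is replaced by the matching value from the importing layer's import_param (see Net::LoadImports in src/caffe/net.cpp below). A standalone C++ illustration of that substitution step, not part of the commit:

// Illustration only: the same textual ${var} substitution the importer
// performs with boost::replace_all before parsing the imported prototxt.
#include <boost/algorithm/string.hpp>
#include <iostream>
#include <string>

int main() {
  std::string proto = "bottom: \"${bottom}\"\nnum_output: ${num_output}\n";
  boost::replace_all(proto, "${bottom}", "/data");
  boost::replace_all(proto, "${num_output}", "20");
  std::cout << proto;  // prints: bottom: "/data"  and  num_output: 20
  return 0;
}
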
69 changes: 13 additions & 56 deletions examples/mnist/lenet_train_test.prototxt
@@ -29,71 +29,28 @@ layers {
   }
   include: { phase: TEST }
 }
-
 layers {
-  name: "conv1"
-  type: CONVOLUTION
-  bottom: "data"
-  top: "conv1"
-  blobs_lr: 1
-  blobs_lr: 2
-  convolution_param {
-    num_output: 20
-    kernel_size: 5
-    stride: 1
-    weight_filler {
-      type: "xavier"
-    }
-    bias_filler {
-      type: "constant"
-    }
-  }
-}
-layers {
-  name: "pool1"
-  type: POOLING
-  bottom: "conv1"
-  top: "pool1"
-  pooling_param {
-    pool: MAX
-    kernel_size: 2
-    stride: 2
-  }
-}
-layers {
-  name: "conv2"
-  type: CONVOLUTION
-  bottom: "pool1"
-  top: "conv2"
-  blobs_lr: 1
-  blobs_lr: 2
-  convolution_param {
-    num_output: 50
-    kernel_size: 5
-    stride: 1
-    weight_filler {
-      type: "xavier"
-    }
-    bias_filler {
-      type: "constant"
-    }
+  name: "cp1"
+  type: IMPORT
+  import_param {
+    net: "examples/mnist/lenet_conv_pool.prototxt"
+    var { name: "bottom" value: "/data" }
+    var { name: "num_output" value: "20" }
   }
 }
 layers {
-  name: "pool2"
-  type: POOLING
-  bottom: "conv2"
-  top: "pool2"
-  pooling_param {
-    pool: MAX
-    kernel_size: 2
-    stride: 2
+  name: "cp2"
+  type: IMPORT
+  import_param {
+    net: "examples/mnist/lenet_conv_pool.prototxt"
+    var { name: "bottom" value: "../cp1/pool" }
+    var { name: "num_output" value: "50" }
   }
 }
 layers {
   name: "ip1"
   type: INNER_PRODUCT
-  bottom: "pool2"
+  bottom: "cp2/pool"
   top: "ip1"
   blobs_lr: 1
   blobs_lr: 2

7 changes: 7 additions & 0 deletions include/caffe/net.hpp
@@ -192,6 +192,13 @@ class Net {
   /// @brief Get misc parameters, e.g. the LR multiplier and weight decay.
   void GetLearningRateAndWeightDecay();
 
+  /// @brief Loads imports, for modular network definitions
+  static void LoadImports(const NetParameter& source, NetParameter* target);
+  static void LoadImports(const NetParameter& source, NetParameter* target,
+      const string& pwd);
+  /// @brief Resolves a layer or blob name, e.g. "../data"
+  static string ResolveImportName(const string& path, const string& pwd);
+
   /// @brief Individual layers in the net
   vector<shared_ptr<Layer<Dtype> > > layers_;
   vector<string> layer_names_;

2 changes: 2 additions & 0 deletions include/caffe/util/io.hpp
@@ -88,6 +88,8 @@ inline void WriteProtoToBinaryFile(
   WriteProtoToBinaryFile(proto, filename.c_str());
 }
 
+string ReadFile(const string& filename);
+
 bool ReadFileToDatum(const string& filename, const int label, Datum* datum);
 
 inline bool ReadFileToDatum(const string& filename, Datum* datum) {

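ReadFile's definition is not among the files shown here (two of the eleven changed files are not loaded on this page). A minimal sketch of what it plausibly looks like, offered as an assumption rather than the commit's actual code:

// Assumed implementation sketch, presumably living in src/caffe/util/io.cpp:
// read a whole file into a string.
#include <fstream>
#include <sstream>
#include <string>

std::string ReadFile(const std::string& filename) {
  std::ifstream file(filename.c_str());
  std::ostringstream buffer;
  buffer << file.rdbuf();  // slurp the entire stream in one shot
  return buffer.str();
}
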
67 changes: 66 additions & 1 deletion src/caffe/net.cpp
@@ -1,3 +1,6 @@
+#include <boost/algorithm/string.hpp>
+#include <google/protobuf/text_format.h>
+
 #include <algorithm>
 #include <map>
 #include <set>
@@ -15,6 +18,7 @@
 #include "caffe/util/upgrade_proto.hpp"
 
 #include "caffe/test/test_caffe_main.hpp"
+using boost::replace_all;
 
 namespace caffe {
 
@@ -32,10 +36,13 @@ Net<Dtype>::Net(const string& param_file) {
 
 template <typename Dtype>
 void Net<Dtype>::Init(const NetParameter& in_param) {
+  // Load import layers
+  NetParameter expanded(in_param);
+  LoadImports(in_param, &expanded);
   // Filter layers based on their include/exclude rules and
   // the current NetState.
   NetParameter filtered_param;
-  FilterNet(in_param, &filtered_param);
+  FilterNet(expanded, &filtered_param);
   LOG(INFO) << "Initializing net from parameters: " << std::endl
       << filtered_param.DebugString();
   // Create a copy of filtered_param with splits added where necessary.
@@ -462,6 +469,64 @@ void Net<Dtype>::AppendParam(const NetParameter& param, const int layer_id,
   }
 }
 
+template <typename Dtype>
+void Net<Dtype>::LoadImports(const NetParameter& source, NetParameter* target) {
+  target->CopyFrom(source);
+  target->clear_layers();
+  LoadImports(source, target, "");
+}
+
+template <typename Dtype>
+void Net<Dtype>::LoadImports(const NetParameter& source, NetParameter* target,
+                             const string& pwd) {
+  for (int i = 0; i < source.layers_size(); ++i) {
+    if (source.layers(i).type() == LayerParameter_LayerType_IMPORT) {
+      const LayerParameter& layer = source.layers(i);
+      CHECK(layer.has_import_param()) << "Missing import_param";
+      const ImportParameter& import = layer.import_param();
+      string proto = ReadFile(import.net());
+      // Replace variables and references
+      for (int j = 0; j < import.var_size(); ++j) {
+        const Pair& p = import.var(j);
+        replace_all(proto, "${" + p.name() + "}", p.value());
+      }
+      NetParameter net;
+      bool parse = google::protobuf::TextFormat::ParseFromString(proto, &net);
+      CHECK(parse) << "Failed to parse NetParameter file: " << import.net();
+      CHECK(layer.has_name() && layer.name().length() > 0)
+          << "Import layer must have a name";
+      LoadImports(net, target, ResolveImportName(layer.name(), pwd));
+    } else {
+      LayerParameter* t = target->add_layers();
+      t->CopyFrom(source.layers(i));
+      t->set_name(ResolveImportName(t->name(), pwd));
+      for (int j = 0; j < source.layers(i).top_size(); ++j)
+        t->set_top(j, ResolveImportName(source.layers(i).top(j), pwd));
+      for (int j = 0; j < source.layers(i).bottom_size(); ++j)
+        t->set_bottom(j, ResolveImportName(source.layers(i).bottom(j), pwd));
+    }
+  }
+}
+
+template <typename Dtype>
+string Net<Dtype>::ResolveImportName(const string& path, const string& pwd) {
+  CHECK(!boost::starts_with(pwd, "/") && !boost::ends_with(pwd, "/"));
+  if (boost::starts_with(path, "/"))
+    return path.substr(1, path.size() - 1);
+  string cpath = path;
+  string cpwd = pwd;
+  while (boost::starts_with(cpath, "../")) {
+    cpath = cpath.substr(3, cpath.size() - 3);
+    size_t i = cpwd.find_last_of('/');
+    cpwd = i == string::npos ? "" : cpwd.substr(0, i);
+  }
+  if (!cpwd.size())
+    return cpath;
+  if (!cpath.size() || cpath == ".")
+    return cpwd;
+  return cpwd + '/' + cpath;
+}
+
 template <typename Dtype>
 void Net<Dtype>::GetLearningRateAndWeightDecay() {
   LOG(INFO) << "Collecting Learning Rate and Weight Decay.";

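To make the resolution rules concrete: a leading "/" anchors a name at the root net, each "../" climbs one import level, and any other name is prefixed with the current import path. The snippet below re-implements the logic standalone so it runs without the Caffe tree (the real helper is the static Net member above); the expected values use the names from the MNIST example:

// Standalone re-implementation of ResolveImportName, for illustration only.
#include <cassert>
#include <string>

using std::string;

static bool StartsWith(const string& s, const string& prefix) {
  return s.compare(0, prefix.size(), prefix) == 0;
}

static string Resolve(const string& path, const string& pwd) {
  if (StartsWith(path, "/"))           // absolute: refers to the root net
    return path.substr(1);
  string cpath = path, cpwd = pwd;
  while (StartsWith(cpath, "../")) {   // climb one import level per "../"
    cpath = cpath.substr(3);
    size_t i = cpwd.find_last_of('/');
    cpwd = i == string::npos ? "" : cpwd.substr(0, i);
  }
  if (cpwd.empty()) return cpath;
  if (cpath.empty() || cpath == ".") return cpwd;
  return cpwd + '/' + cpath;           // relative: prefix with import path
}

int main() {
  assert(Resolve("/data", "cp1") == "data");            // root data blob
  assert(Resolve("../cp1/pool", "cp2") == "cp1/pool");  // sibling import's top
  assert(Resolve("conv", "cp1") == "cp1/conv");         // name inside an import
  return 0;
}
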
20 changes: 18 additions & 2 deletions src/caffe/proto/caffe.proto
@@ -206,7 +206,7 @@ message NetStateRule {
 // NOTE
 // Update the next available ID when you add a new LayerParameter field.
 //
-// LayerParameter next available ID: 42 (last added: exp_param)
+// LayerParameter next available ID: 43 (last added: import_param)
 message LayerParameter {
   repeated string bottom = 2; // the name of the bottom blobs
   repeated string top = 3; // the name of the top blobs
@@ -227,7 +227,7 @@ message LayerParameter {
   // line above the enum. Update the next available ID when you add a new
   // LayerType.
   //
-  // LayerType next available ID: 39 (last added: EXP)
+  // LayerType next available ID: 40 (last added: IMPORT)
   enum LayerType {
     // "NONE" layer type is 0th enum element so that we don't cause confusion
     // by defaulting to an existent LayerType (instead, should usually error if
@@ -252,6 +252,7 @@ message LayerParameter {
     HINGE_LOSS = 28;
     IM2COL = 11;
     IMAGE_DATA = 12;
+    IMPORT = 39;
     INFOGAIN_LOSS = 13;
     INNER_PRODUCT = 14;
     LRN = 15;
@@ -313,6 +314,7 @@ message LayerParameter {
   optional HDF5OutputParameter hdf5_output_param = 14;
   optional HingeLossParameter hinge_loss_param = 29;
   optional ImageDataParameter image_data_param = 15;
+  optional ImportParameter import_param = 42;
   optional InfogainLossParameter infogain_loss_param = 16;
   optional InnerProductParameter inner_product_param = 17;
   optional LRNParameter lrn_param = 18;
@@ -342,6 +344,20 @@ message LayerParameter {
   optional V0LayerParameter layer = 1;
 }
 
+message Pair {
+  required string name = 1;
+  required string value = 2;
+}
+
+// Message that stores parameters used by ImportLayer
+message ImportParameter {
+  // Proto file to import
+  required string net = 1;
+  // Variable names to replace before importing the file. Variables can
+  // be used in the file in this format: ${name}
+  repeated Pair var = 2;
+}
+
 // Message that stores parameters used to apply transformation
 // to the data layer's data
 message TransformationParameter {

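The new messages slot into the standard protobuf toolchain, so an IMPORT layer can also be built programmatically. A short example using the accessors protobuf generates for these messages (an illustration, not code from the commit):

// Builds the same IMPORT layer as the MNIST example above, then prints it
// in the text format used by the prototxt files.
#include <iostream>
#include "caffe/proto/caffe.pb.h"

int main() {
  caffe::NetParameter net;
  caffe::LayerParameter* layer = net.add_layers();
  layer->set_name("cp1");
  layer->set_type(caffe::LayerParameter_LayerType_IMPORT);
  caffe::ImportParameter* import = layer->mutable_import_param();
  import->set_net("examples/mnist/lenet_conv_pool.prototxt");
  caffe::Pair* var = import->add_var();
  var->set_name("bottom");
  var->set_value("/data");
  var = import->add_var();
  var->set_name("num_output");
  var->set_value("20");
  std::cout << net.DebugString();  // same shape as the prototxt above
  return 0;
}
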
21 changes: 21 additions & 0 deletions src/caffe/test/test_data/module.prototxt
@@ -0,0 +1,21 @@
+layers: {
+  name: 'innerproduct'
+  type: INNER_PRODUCT
+  inner_product_param {
+    num_output: ${num_output}
+    weight_filler {
+      type: 'gaussian'
+      std: 0.01
+    }
+    bias_filler {
+      type: 'constant'
+      value: 0
+    }
+  }
+  blobs_lr: 1.
+  blobs_lr: 2.
+  weight_decay: 1.
+  weight_decay: 0.
+  bottom: '../data'
+  top: 'innerproduct'
+}