Skip to content

Commit

Permalink
Merge pull request #2086 from longjon/python-net-spec
Browse files Browse the repository at this point in the history
Python net specification
  • Loading branch information
shelhamer committed Jun 30, 2015
2 parents af37efd + 1cdad89 commit 1d6cac2
Show file tree
Hide file tree
Showing 4 changed files with 325 additions and 0 deletions.
54 changes: 54 additions & 0 deletions examples/pycaffe/caffenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from caffe import layers as L, params as P, to_proto
from caffe.proto import caffe_pb2

# helper function for common structures

def conv_relu(bottom, ks, nout, stride=1, pad=0, group=1):
conv = L.Convolution(bottom, kernel_size=ks, stride=stride,
num_output=nout, pad=pad, group=group)
return conv, L.ReLU(conv, in_place=True)

def fc_relu(bottom, nout):
fc = L.InnerProduct(bottom, num_output=nout)
return fc, L.ReLU(fc, in_place=True)

def max_pool(bottom, ks, stride=1):
return L.Pooling(bottom, pool=P.Pooling.MAX, kernel_size=ks, stride=stride)

def caffenet(lmdb, batch_size=256, include_acc=False):
data, label = L.Data(source=lmdb, backend=P.Data.LMDB, batch_size=batch_size, ntop=2,
transform_param=dict(crop_size=227, mean_value=[104, 117, 123], mirror=True))

# the net itself
conv1, relu1 = conv_relu(data, 11, 96, stride=4)
pool1 = max_pool(relu1, 3, stride=2)
norm1 = L.LRN(pool1, local_size=5, alpha=1e-4, beta=0.75)
conv2, relu2 = conv_relu(norm1, 5, 256, pad=2, group=2)
pool2 = max_pool(relu2, 3, stride=2)
norm2 = L.LRN(pool2, local_size=5, alpha=1e-4, beta=0.75)
conv3, relu3 = conv_relu(norm2, 3, 384, pad=1)
conv4, relu4 = conv_relu(relu3, 3, 384, pad=1, group=2)
conv5, relu5 = conv_relu(relu4, 3, 256, pad=1, group=2)
pool5 = max_pool(relu5, 3, stride=2)
fc6, relu6 = fc_relu(pool5, 4096)
drop6 = L.Dropout(relu6, in_place=True)
fc7, relu7 = fc_relu(drop6, 4096)
drop7 = L.Dropout(relu7, in_place=True)
fc8 = L.InnerProduct(drop7, num_output=1000)
loss = L.SoftmaxWithLoss(fc8, label)

if include_acc:
acc = L.Accuracy(fc8, label)
return to_proto(loss, acc)
else:
return to_proto(loss)

def make_net():
with open('train.prototxt', 'w') as f:
print >>f, caffenet('/path/to/caffe-train-lmdb')

with open('test.prototxt', 'w') as f:
print >>f, caffenet('/path/to/caffe-val-lmdb', batch_size=50, include_acc=True)

if __name__ == '__main__':
make_net()
1 change: 1 addition & 0 deletions python/caffe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from .classifier import Classifier
from .detector import Detector
from . import io
from .net_spec import layers, params, NetSpec, to_proto
203 changes: 203 additions & 0 deletions python/caffe/net_spec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
"""Python net specification.
This module provides a way to write nets directly in Python, using a natural,
functional style. See examples/python_nets/caffenet.py for an example.
Currently this works as a thin wrapper around the Python protobuf interface,
with layers and parameters automatically generated for the "layers" and
"params" pseudo-modules, which are actually objects using __getattr__ magic
to generate protobuf messages.
Note that when using to_proto or Top.to_proto, names of intermediate blobs will
be automatically generated. To explicitly specify blob names, use the NetSpec
class -- assign to its attributes directly to name layers, and call
NetSpec.to_proto to serialize all assigned layers.
This interface is expected to continue to evolve as Caffe gains new capabilities
for specifying nets. In particular, the automatically generated layer names
are not guaranteed to be forward-compatible.
"""

from collections import OrderedDict

from .proto import caffe_pb2
from google import protobuf


def param_name_dict():
"""Find out the correspondence between layer names and parameter names."""

layer = caffe_pb2.LayerParameter()
# get all parameter names (typically underscore case) and corresponding
# type names (typically camel case), which contain the layer names
# (note that not all parameters correspond to layers, but we'll ignore that)
param_names = [s for s in dir(layer) if s.endswith('_param')]
param_type_names = [type(getattr(layer, s)).__name__ for s in param_names]
# strip the final '_param' or 'Parameter'
param_names = [s[:-len('_param')] for s in param_names]
param_type_names = [s[:-len('Parameter')] for s in param_type_names]
return dict(zip(param_type_names, param_names))


def to_proto(*tops):
"""Generate a NetParameter that contains all layers needed to compute
all arguments."""

if not isinstance(tops, tuple):
tops = (tops,)
layers = OrderedDict()
autonames = {}
for top in tops:
top.fn._to_proto(layers, {}, autonames)
net = caffe_pb2.NetParameter()
net.layer.extend(layers.values())
return net


def assign_proto(proto, name, val):
"""Assign a Python object to a protobuf message, based on the Python
type (in recursive fashion). Lists become repeated fields/messages, dicts
become messages, and other types are assigned directly."""

if isinstance(val, list):
if isinstance(val[0], dict):
for item in val:
proto_item = getattr(proto, name).add()
for k, v in item.iteritems():
assign_proto(proto_item, k, v)
else:
getattr(proto, name).extend(val)
elif isinstance(val, dict):
for k, v in val.iteritems():
assign_proto(getattr(proto, name), k, v)
else:
setattr(proto, name, val)


class Top(object):
"""A Top specifies a single output blob (which could be one of several
produced by a layer.)"""

def __init__(self, fn, n):
self.fn = fn
self.n = n

def to_proto(self):
"""Generate a NetParameter that contains all layers needed to compute
this top."""

return to_proto(self)


class Function(object):
"""A Function specifies a layer, its parameters, and its inputs (which
are Tops from other layers)."""

def __init__(self, type_name, inputs, params):
self.type_name = type_name
self.inputs = inputs
self.params = params
self.ntop = self.params.get('ntop', 1)
# use del to make sure kwargs are not double-processed as layer params
if 'ntop' in self.params:
del self.params['ntop']
self.in_place = self.params.get('in_place', False)
if 'in_place' in self.params:
del self.params['in_place']
self.tops = tuple(Top(self, n) for n in range(self.ntop))

def _get_name(self, top, names, autonames):
if top not in names:
n = autonames.setdefault(top.fn.type_name, 1)
autonames[top.fn.type_name] += 1
names[top] = top.fn.type_name + str(n)
return names[top]

def _to_proto(self, layers, names, autonames):
if self in layers:
return
bottom_names = []
for inp in self.inputs:
inp.fn._to_proto(layers, names, autonames)
bottom_names.append(layers[inp.fn].top[inp.n])
layer = caffe_pb2.LayerParameter()
layer.type = self.type_name
layer.bottom.extend(bottom_names)

if self.in_place:
layer.top.extend(layer.bottom)
else:
for top in self.tops:
layer.top.append(self._get_name(top, names, autonames))
layer.name = self._get_name(self.tops[0], names, autonames)

for k, v in self.params.iteritems():
# special case to handle generic *params
if k.endswith('param'):
assign_proto(layer, k, v)
else:
try:
assign_proto(getattr(layer,
_param_names[self.type_name] + '_param'), k, v)
except (AttributeError, KeyError):
assign_proto(layer, k, v)

layers[self] = layer


class NetSpec(object):
"""A NetSpec contains a set of Tops (assigned directly as attributes).
Calling NetSpec.to_proto generates a NetParameter containing all of the
layers needed to produce all of the assigned Tops, using the assigned
names."""

def __init__(self):
super(NetSpec, self).__setattr__('tops', OrderedDict())

def __setattr__(self, name, value):
self.tops[name] = value

def __getattr__(self, name):
return self.tops[name]

def to_proto(self):
names = {v: k for k, v in self.tops.iteritems()}
autonames = {}
layers = OrderedDict()
for name, top in self.tops.iteritems():
top.fn._to_proto(layers, names, autonames)
net = caffe_pb2.NetParameter()
net.layer.extend(layers.values())
return net


class Layers(object):
"""A Layers object is a pseudo-module which generates functions that specify
layers; e.g., Layers().Convolution(bottom, kernel_size=3) will produce a Top
specifying a 3x3 convolution applied to bottom."""

def __getattr__(self, name):
def layer_fn(*args, **kwargs):
fn = Function(name, args, kwargs)
if fn.ntop == 1:
return fn.tops[0]
else:
return fn.tops
return layer_fn


class Parameters(object):
"""A Parameters object is a pseudo-module which generates constants used
in layer parameters; e.g., Parameters().Pooling.MAX is the value used
to specify max pooling."""

def __getattr__(self, name):
class Param:
def __getattr__(self, param_name):
return getattr(getattr(caffe_pb2, name + 'Parameter'), param_name)
return Param()


_param_names = param_name_dict()
layers = Layers()
params = Parameters()
67 changes: 67 additions & 0 deletions python/caffe/test/test_net_spec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import unittest
import tempfile
import caffe
from caffe import layers as L
from caffe import params as P

def lenet(batch_size):
n = caffe.NetSpec()
n.data, n.label = L.DummyData(shape=[dict(dim=[batch_size, 1, 28, 28]),
dict(dim=[batch_size, 1, 1, 1])],
transform_param=dict(scale=1./255), ntop=2)
n.conv1 = L.Convolution(n.data, kernel_size=5, num_output=20,
weight_filler=dict(type='xavier'))
n.pool1 = L.Pooling(n.conv1, kernel_size=2, stride=2, pool=P.Pooling.MAX)
n.conv2 = L.Convolution(n.pool1, kernel_size=5, num_output=50,
weight_filler=dict(type='xavier'))
n.pool2 = L.Pooling(n.conv2, kernel_size=2, stride=2, pool=P.Pooling.MAX)
n.ip1 = L.InnerProduct(n.pool2, num_output=500,
weight_filler=dict(type='xavier'))
n.relu1 = L.ReLU(n.ip1, in_place=True)
n.ip2 = L.InnerProduct(n.relu1, num_output=10,
weight_filler=dict(type='xavier'))
n.loss = L.SoftmaxWithLoss(n.ip2, n.label)
return n.to_proto()

def anon_lenet(batch_size):
data, label = L.DummyData(shape=[dict(dim=[batch_size, 1, 28, 28]),
dict(dim=[batch_size, 1, 1, 1])],
transform_param=dict(scale=1./255), ntop=2)
conv1 = L.Convolution(data, kernel_size=5, num_output=20,
weight_filler=dict(type='xavier'))
pool1 = L.Pooling(conv1, kernel_size=2, stride=2, pool=P.Pooling.MAX)
conv2 = L.Convolution(pool1, kernel_size=5, num_output=50,
weight_filler=dict(type='xavier'))
pool2 = L.Pooling(conv2, kernel_size=2, stride=2, pool=P.Pooling.MAX)
ip1 = L.InnerProduct(pool2, num_output=500,
weight_filler=dict(type='xavier'))
relu1 = L.ReLU(ip1, in_place=True)
ip2 = L.InnerProduct(relu1, num_output=10,
weight_filler=dict(type='xavier'))
loss = L.SoftmaxWithLoss(ip2, label)
return loss.to_proto()

class TestNetSpec(unittest.TestCase):
def load_net(self, net_proto):
f = tempfile.NamedTemporaryFile(delete=False)
f.write(str(net_proto))
f.close()
return caffe.Net(f.name, caffe.TEST)

def test_lenet(self):
"""Construct and build the Caffe version of LeNet."""

net_proto = lenet(50)
# check that relu is in-place
self.assertEqual(net_proto.layer[6].bottom,
net_proto.layer[6].top)
net = self.load_net(net_proto)
# check that all layers are present
self.assertEqual(len(net.layers), 9)

# now the check the version with automatically-generated layer names
net_proto = anon_lenet(50)
self.assertEqual(net_proto.layer[6].bottom,
net_proto.layer[6].top)
net = self.load_net(net_proto)
self.assertEqual(len(net.layers), 9)

0 comments on commit 1d6cac2

Please sign in to comment.