Skip to content

Commit

Permalink
[pycaffe] WIP on net spec
Browse files Browse the repository at this point in the history
  • Loading branch information
longjon committed Jun 18, 2015
1 parent 2ddadd2 commit 186c40c
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 21 deletions.
2 changes: 1 addition & 1 deletion python/caffe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
from .classifier import Classifier
from .detector import Detector
from . import io
from .layers import layers, params, to_proto
from .layers import layers, params, NetSpec
54 changes: 34 additions & 20 deletions python/caffe/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,19 @@ def assign_proto(proto, name, val):
else:
setattr(proto, name, val)

def to_proto(tops, names=None):
    """Serialize the layer graph reachable from ``tops`` into a NetParameter.

    tops: a Top, or a tuple of Tops, whose producing Functions are walked.
    names: optional {Top: name} mapping consulted when naming layers/tops.
    Returns a ``caffe_pb2.NetParameter`` holding every reachable layer.
    """
    # Accept a bare Top as well as a tuple of them.
    tops = tops if isinstance(tops, tuple) else (tops,)
    names = {} if names is None else names
    autonames = {}
    # Collected layers, keyed by producing Function, in traversal order.
    collected = OrderedDict()
    for t in tops:
        t.fn._to_proto(collected, names, autonames)

    proto = caffe_pb2.NetParameter()
    proto.layer.extend(collected.values())
    return proto

class Top:
class Top(object):
    """Handle for one output (top) of a Function in the layer graph."""

    def __init__(self, fn, n):
        # fn: the Function that produces this top; n: its output index.
        self.fn = fn
        self.n = n

    def to_proto(self):
        """Return a NetParameter for the subgraph producing this single top."""
        collected = OrderedDict()
        self.fn._to_proto(collected, {}, {})
        proto = caffe_pb2.NetParameter()
        proto.layer.extend(collected.values())
        return proto

class Function(object):
def __init__(self, type_name, inputs, params):
self.type_name = type_name
self.inputs = inputs
Expand All @@ -64,10 +57,11 @@ def _get_name(self, top, names, autonames):
return names[top]

def _to_proto(self, layers, names, autonames):
if self in layers:
return
bottom_names = []
for inp in self.inputs:
if inp.fn not in layers:
inp.fn._to_proto(layers, names, autonames)
inp.fn._to_proto(layers, names, autonames)
bottom_names.append(layers[inp.fn].top[inp.n])
layer = caffe_pb2.LayerParameter()
layer.type = self.type_name
Expand All @@ -92,7 +86,27 @@ def _to_proto(self, layers, names, autonames):

layers[self] = layer

class Layers:
class NetSpec(object):
    """Ordered collection of named Tops, serializable to a net proto.

    Attribute assignment records a named top (``spec.conv1 = ...``) and
    attribute access returns it.  Assignment order is preserved so layers
    are emitted in declaration order by ``to_proto``.
    """

    def __init__(self):
        # Bypass our own __setattr__ so 'tops' lands in the instance dict
        # instead of being recorded as a network top.
        super(NetSpec, self).__setattr__('tops', OrderedDict())

    def __setattr__(self, name, value):
        # Every attribute assignment is recorded as a named top.
        self.tops[name] = value

    def __getattr__(self, name):
        # NOTE(review): raises KeyError (not AttributeError) for unknown
        # names — kept as-is since callers may depend on it.
        return self.tops[name]

    def to_proto(self):
        """Serialize all recorded tops into a ``caffe_pb2.NetParameter``."""
        # .items() instead of Python-2-only .iteritems(): identical
        # semantics here, and works on both Python 2 and 3.
        names = {v: k for k, v in self.tops.items()}
        autonames = {}
        layers = OrderedDict()
        for name, top in self.tops.items():
            top.fn._to_proto(layers, names, autonames)
        net = caffe_pb2.NetParameter()
        net.layer.extend(layers.values())
        return net

class Layers(object):
def __getattr__(self, name):
def layer_fn(*args, **kwargs):
fn = Function(name, args, kwargs)
Expand All @@ -102,7 +116,7 @@ def layer_fn(*args, **kwargs):
return fn.tops
return layer_fn

class Parameters:
class Parameters(object):
def __getattr__(self, name):
class Param:
def __getattr__(self, param_name):
Expand Down

0 comments on commit 186c40c

Please sign in to comment.