-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathCNN.m
121 lines (95 loc) · 4.19 KB
/
CNN.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
classdef CNN < handle & NetworkLayer
%CNN Convolutional neural-network layer: valid convolution + sigmoid.
%Based on Mathworks's default filter2.
%-------------------
%TODO: larger stride?
%-------------------
%input is width*height*num_channels*num_data
%output is (width-ws+1)*(height-ws+1)*numunits*num_data
properties
	ws;                  % square filter size (ws x ws)
	numunits;            % number of output feature maps (filters)
	numchannels;         % number of input channels; set in setPar
	weights;             %ws*ws*numchannels*numunits
	biases;              %numunits*1 (one bias per feature map)
	l2_C = 1e-4;         % L2 weight-decay coefficient
	init_weight = 1e-3;  % scale of random weight initialization
	dweights;            % gradient buffer, same shape as weights
	dbiases;             % gradient buffer, same shape as biases
end
methods
	function self = CNN(numunits, ws)
	%CNN Construct a conv layer with numunits filters of size ws x ws.
	%Weights are allocated lazily in setPar once input dims are known.
		self.ws = ws;
		self.numunits = numunits;
	end
	function [] = setPar(self,feadim) %feadim = (W, H, numchannels)
	%setPar Record input/output dims and (re)initialize parameters.
		% BUGFIX: the original compared self.in_size to feadim AFTER
		% assigning self.in_size = feadim, so the size-change test was
		% always false and weights were only ever created when empty.
		% Decide whether reinitialization is needed BEFORE overwriting.
		needs_init = isempty(self.weights) || isempty(self.in_size) ...
			|| numel(self.in_size) ~= numel(feadim) ...
			|| nnz(self.in_size(:) ~= feadim(:)) > 0;
		self.in_size = feadim;
		self.out_size = [feadim(1)-self.ws+1, feadim(2)-self.ws+1, self.numunits];
		self.numchannels = feadim(3);
		if needs_init
			self.weights = self.init_weight*Utils.randn([self.ws,self.ws,self.numchannels,self.numunits]);
			self.biases = Utils.zeros([self.numunits,1]);
		end
		self.paramNum = numel(self.weights) + numel(self.biases);
	end
	function [] = reset(self)
	%reset Re-randomize weights and zero biases, keeping current shapes.
		% BUGFIX: property is init_weight (singular); the original
		% referenced self.init_weights, which errors at runtime.
		self.weights = self.init_weight*Utils.randn(size(self.weights));
		self.biases = Utils.zeros(size(self.biases));
	end
	function object = gradCheckObject(self)
	%gradCheckObject Small CNN instance for numerical gradient checking.
		ws = 3;
		numunits = 2;
		object = CNN(numunits,ws);
	end
	function fprop(self)
	%fprop Forward pass: valid convolution, per-map bias, sigmoid.
		% NOTE: removed a dead Utils.zeros allocation of self.OUT that
		% was immediately overwritten by convInference.
		%this is very slow!!
		self.OUT = Utils.convInference(self.IN,self.weights);
		% add the per-feature-map bias to each output map
		for u = 1 : self.numunits
			self.OUT(:,:,u,:) = self.OUT(:,:,u,:) + self.biases(u);
		end
		self.OUT = Utils.sigmoid(self.OUT);
	end
	function [f derivative] = bprop(self,f,derivative)
	%bprop Backward pass: accumulate parameter gradients and (optionally)
	%propagate the derivative to the layer below. Adds the L2 penalty
	%0.5*l2_C*||W||^2 to the objective f when parameters are updated.
		if isempty(self.dweights)
			self.dweights = Utils.zeros(size(self.weights));
			self.dbiases = Utils.zeros(size(self.biases));
		end
		% sigmoid derivative: OUT.*(1-OUT), chained with incoming derivative
		da = self.OUT.*(1-self.OUT).*derivative; %numunits*numdata
		if self.skip_update ~= true
			self.dweights = Utils.convGradient(self.IN, da);
			self.dweights = self.dweights + self.l2_C*self.weights; % weight decay
			% sum over spatial dims (1,2) and data dim (4) -> numunits*1
			self.dbiases = squeeze(sum(sum(sum(da,1),2),4));
			f = f + 0.5*self.l2_C*norm(self.weights(:))^2;
		end
		if ~self.skip_passdown
			derivative = Utils.convReconstruct(da, self.weights);
		end
	end
	function clearTempData(self)
	%clearTempData Release cached activations and gradient buffers.
		self.IN = [];
		self.OUT= [];
		self.dweights = [];
		self.dbiases = [];
	end
	function param = getParam(self)
	%getParam Return trainable parameters, or {} when updates are skipped.
		if ~self.skip_update
			param = {self.weights, self.biases};
		else
			param = {};
		end
	end
	function param = getGradParam(self)
	%getGradParam Return parameter gradients, or {} when updates are skipped.
		if ~self.skip_update
			param = {self.dweights, self.dbiases};
		else
			param = {};
		end
	end
	function setParam(self,paramvec)
	%setParam Unpack a flat parameter vector into weights then biases.
		if ~self.skip_update
			self.weights = reshape(paramvec(1:numel(self.weights)),size(self.weights));
			self.biases = reshape(paramvec(numel(self.weights)+1:end),size(self.biases));
		end
	end
end
end