-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathPooling2D.m
111 lines (98 loc) · 4.33 KB
/
Pooling2D.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
classdef Pooling2D < handle & NetworkLayer
    % Pooling2D  2-D max / average pooling layer with a square window.
    %
    % The layer reads self.IN and writes self.OUT; both are indexed with four
    % subscripts (rows, cols, :, :). The third dimension has in_size(3)
    % entries and the fourth has self.numdata entries.
    %
    % NOTE(review): in_size, out_size, IN, OUT, numdata and skip_passdown are
    % not declared in this file; presumably they are inherited from
    % NetworkLayer - confirm against that class.
    properties
        ws;      % window size (the window is ws-by-ws)
        stride;  % step between consecutive windows, same in both spatial dims
        type;    % pooling mode: 'max' or 'avg'
        max_idx; % max pooling only: code (i-1)*ws+j of the winning window
                 % offset for every output element (0 if no element won)
    end

    methods
        function self = Pooling2D(ws, stride, type)
            % Store the pooling configuration; nothing is validated here.
            self.ws = ws;
            self.stride = stride;
            self.type = type;
        end

        function setPar(self, in_size)
            % Record the input size and derive the output size.
            %   in_size : [rows, cols, channels] of a single input sample.
            % Raises 'pooling dimension mismatch' when the window/stride do
            % not tile the input exactly or the window exceeds the input.
            self.in_size = in_size;
            new_size = pooling2dOutputSize(in_size, self.ws, self.stride);
            self.out_size = [new_size, in_size(3)];
        end

        function object = gradCheckObject(self)
            % Build a small layer of the same pooling type (3x3 window,
            % stride 2) for gradient checking.
            object = Pooling2D(3, 2, self.type);
        end

        function fprop(self)
            % Forward pass: pool self.IN into self.OUT.
            % For 'max', self.max_idx records which window offset produced
            % each output so that bprop / fprop_rev can route values back.
            out_dims = [pooling2dOutputSize(self.in_size, self.ws, self.stride), ...
                        self.in_size(3), self.numdata];
            if strcmp(self.type, 'max')
                % Start from -Inf so any finite input wins the first compare.
                % Inputs that are NaN or -Inf never satisfy tmp > OUT, so such
                % windows keep OUT = -Inf and max_idx = 0.
                self.OUT = -Inf(out_dims);
                self.max_idx = zeros(out_dims);
            else
                self.OUT = zeros(out_dims);
            end

            % Loop over the ws*ws window offsets rather than over output
            % positions: each iteration handles one offset for all windows.
            for i = 1 : self.ws
                for j = 1 : self.ws
                    tmp = self.IN(i:self.stride:end-self.ws+i, j:self.stride:end-self.ws+j, :, :);
                    switch self.type
                        case 'max'
                            mask = tmp > self.OUT;
                            self.max_idx(mask) = (i-1)*self.ws + j;
                            self.OUT(mask) = tmp(mask);
                        case 'avg'
                            self.OUT = self.OUT + tmp;
                        otherwise
                            error('non-implemented type');
                    end
                end
            end

            if strcmp(self.type, 'avg')
                self.OUT = self.OUT / (self.ws^2);
            end
        end

        function [OUT] = fprop_rev(self, IN)
            % Pool an arbitrary array IN using the state of the last fprop
            % (for computing the gradient in ConvAutoencoder).
            % For 'max' the previously recorded self.max_idx selects the
            % element of each window; for 'avg' the window mean is returned.
            % self.OUT and self.max_idx are not modified.
            %
            % The size check is done on the computed dimensions before the
            % allocation (checking the zero-filled array itself can never
            % detect a mismatch).
            out_dims = [pooling2dOutputSize(self.in_size, self.ws, self.stride), ...
                        self.in_size(3), self.numdata];
            OUT = zeros(out_dims);

            for i = 1 : self.ws
                for j = 1 : self.ws
                    tmp = IN(i:self.stride:end-self.ws+i, j:self.stride:end-self.ws+j, :, :);
                    switch self.type
                        case 'max'
                            idx = self.max_idx == (i-1)*self.ws + j;
                            OUT(idx) = tmp(idx);
                        case 'avg'
                            OUT = OUT + tmp;
                        otherwise
                            error('non-implemented type');
                    end
                end
            end

            if strcmp(self.type, 'avg')
                OUT = OUT / (self.ws^2);
            end
        end

        function [f, derivative] = bprop(self, f, derivative)
            % Backward pass: turn the gradient w.r.t. self.OUT into the
            % gradient w.r.t. self.IN. f is returned unchanged. When
            % self.skip_passdown is true, derivative is returned unchanged.
            if ~self.skip_passdown
                dX = zeros([self.in_size self.numdata]);
                for i = 1 : self.ws
                    for j = 1 : self.ws
                        switch self.type
                            case 'max'
                                % Only the offset that won in fprop receives gradient.
                                contrib = derivative .* (self.max_idx == (i-1)*self.ws + j);
                            case 'avg'
                                contrib = derivative;
                            otherwise
                                % Same failure mode as fprop / fprop_rev instead of
                                % silently returning an all-zero gradient.
                                error('non-implemented type');
                        end
                        rows = i:self.stride:size(dX,1)-self.ws+i;
                        cols = j:self.stride:size(dX,2)-self.ws+j;
                        % Accumulate: overlapping windows (stride < ws) add up.
                        dX(rows, cols, :, :) = dX(rows, cols, :, :) + contrib;
                    end
                end
                if strcmp(self.type, 'avg')
                    dX = dX / (self.ws^2);
                end
                derivative = dX;
            end
        end
    end
end

function out_hw = pooling2dOutputSize(in_size, ws, stride)
% Spatial output size [rows, cols] of a ws-by-ws pooling window moved with
% the given stride over an input whose first two sizes are in_size(1:2).
% Errors when the windows do not tile the input exactly (non-integer size)
% or when the window is larger than the input (size below 1).
out_hw = [(in_size(1)-ws)/stride+1, (in_size(2)-ws)/stride+1];
if nnz(mod(out_hw,1)) > 0 || any(out_hw < 1)
    error('pooling dimension mismatch');
end
end