forked from Element-Research/rnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathRecurrentAttention.lua
184 lines (156 loc) · 7.03 KB
/
RecurrentAttention.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
------------------------------------------------------------------------
--[[ RecurrentAttention ]]--
-- Ref. A. http://papers.nips.cc/paper/5542-recurrent-models-of-visual-attention.pdf
-- B. http://incompleteideas.net/sutton/williams-92.pdf
-- Module which takes an RNN as argument, along with other
-- hyper-parameters: the action module (an action-sampling module
-- like ReinforceNormal), the maximum number of steps (nStep) and
-- the size of the hidden state fed to the action module (hiddenSize).
------------------------------------------------------------------------
-- Declares the nn.RecurrentAttention class as a subclass of nn.AbstractSequencer.
-- `parent` is the base class table, used below to call inherited methods.
local RecurrentAttention, parent = torch.class("nn.RecurrentAttention", "nn.AbstractSequencer")
--- Constructor.
-- @param rnn module run at every step; it is wrapped in an nn.Recursor
--        unless it already is an nn.AbstractRecurrent
-- @param action nn.Module producing the action of every step; it is wrapped
--        in an nn.Recursor unless it already is an nn.AbstractRecurrent
-- @param nStep number of steps performed by each forward pass
-- @param hiddenSize table of numbers: the sizes (after the first dimension)
--        of the zero tensor given to the action module on the first step
function RecurrentAttention:__init(rnn, action, nStep, hiddenSize)
   parent.__init(self)
   assert(torch.isTypeOf(action, 'nn.Module'))
   assert(torch.type(nStep) == 'number')
   assert(torch.type(hiddenSize) == 'table')
   assert(torch.type(hiddenSize[1]) == 'number', "Does not support table hidden layers" )

   -- returns the module itself when it is already recurrent,
   -- otherwise decorates it with a Recursor
   local function asRecurrent(module)
      if torch.isTypeOf(module, 'nn.AbstractRecurrent') then
         return module
      end
      return nn.Recursor(module)
   end

   self.rnn = asRecurrent(rnn)
   -- back-propagation through time is performed online,
   -- i.e. in the reverse order of the forward steps
   self.rnn:backwardOnline()
   -- clear the copyInputs / copyGradOutputs flags of every recurrent
   -- module listed by the rnn
   for _, inner in ipairs(self.rnn:listModules()) do
      if torch.isTypeOf(inner, "nn.AbstractRecurrent") then
         inner.copyInputs = false
         inner.copyGradOutputs = false
      end
   end

   -- the action module samples the actions (e.g. x,y) of each example
   self.action = asRecurrent(action)
   self.action:backwardOnline()

   self.hiddenSize = hiddenSize
   self.nStep = nStep
   self.modules = {self.rnn, self.action}

   self.output = {}             -- per-step rnn outputs
   self.actions = {}            -- per-step action outputs
   self.forwardActions = false  -- when true, each output is {rnn output, action}
   self.gradHidden = {}         -- per-step gradients w.r.t. the rnn outputs
end
--- Forward pass: runs self.nStep steps over the same input.
-- At each step the action module produces an action (from a zero tensor on
-- the first step, from the previous entry of self.output afterwards) and the
-- rnn is forwarded on the table {input, action}.
-- @param input tensor; its first dimension sizes the zero tensor given to
--        the action module on step 1 (presumably the batch dimension --
--        TODO confirm)
-- @return self.output, a table holding one entry per step: the rnn output,
--         or {rnn output, action} when self.forwardActions is set
function RecurrentAttention:updateOutput(input)
   -- `table.unpack` exists in Lua 5.2+ (and 5.2-compat LuaJIT builds) while
   -- plain Lua 5.1 / LuaJIT only provide the global `unpack`: use whichever
   -- is available so that the module runs on both.
   local unpack = table.unpack or unpack

   self.rnn:forget()
   self.action:forget()

   for step = 1, self.nStep do
      local actionInput
      if step == 1 then
         -- sample the initial action by forwarding zeros through the action
         self._initInput = self._initInput or input.new()
         self._initInput:resize(input:size(1), unpack(self.hiddenSize)):zero()
         actionInput = self._initInput
      else
         -- sample the action from the previous hidden activation (rnn output)
         -- NOTE(review): when self.forwardActions is true the previous entry
         -- is the {output, action} pair rather than the bare rnn output --
         -- confirm that the action module is meant to receive that table.
         actionInput = self.output[step-1]
      end
      self.actions[step] = self.action:updateOutput(actionInput)

      -- the rnn handles the recurrence internally
      local output = self.rnn:updateOutput{input, self.actions[step]}
      self.output[step] = self.forwardActions and {output, self.actions[step]} or output
   end

   return self.output
end
--- Backward pass w.r.t. the input (back-propagation through time).
-- Steps are visited in reverse order. For every step the gradient is first
-- propagated through the action module, then through the rnn; the first
-- element of each rnn gradInput is summed into self.gradInput.
-- Also fills self.gradHidden[step] (and self._gradAction when actions are not
-- forwarded), which accGradParameters / accUpdateGradParameters read later.
-- @param input the input given to updateOutput
-- @param gradOutput table of nStep elements; each one is the gradient w.r.t.
--        the step output, or a {gradOutput, gradAction} pair when
--        self.forwardActions is set
-- @return self.gradInput
function RecurrentAttention:updateGradInput(input, gradOutput)
   assert(self.rnn.step - 1 == self.nStep, "inconsistent rnn steps")
   assert(torch.type(gradOutput) == 'table', "expecting gradOutput table")
   assert(#gradOutput == self.nStep, "gradOutput should have nStep elements")

   -- back-propagate through time (BPTT)
   for step=self.nStep,1,-1 do
      -- 1. backward through the action layer
      local gradOutput_, gradAction_ = gradOutput[step]
      if self.forwardActions then
         -- Index the pair explicitly instead of calling the global `unpack`,
         -- which does not exist on Lua 5.2+ builds without compatibility
         -- mode; this also matches how accGradParameters reads the pair.
         gradOutput_, gradAction_ = gradOutput[step][1], gradOutput[step][2]
      else
         -- Note : gradOutput is ignored by REINFORCE modules so we give a zero Tensor instead
         self._gradAction = self._gradAction or self.action.output.new()
         if not self._gradAction:isSameSizeAs(self.action.output) then
            self._gradAction:resizeAs(self.action.output):zero()
         end
         gradAction_ = self._gradAction
      end

      if step == self.nStep then
         -- last step: the hidden gradient is just the output gradient
         self.gradHidden[step] = nn.rnn.recursiveCopy(self.gradHidden[step], gradOutput_)
      else
         -- gradHidden = gradOutput + gradAction
         -- (gradHidden[step] already holds the action gradient stored while
         -- processing step+1)
         nn.rnn.recursiveAdd(self.gradHidden[step], gradOutput_)
      end

      if step == 1 then
         -- backward through initial starting actions
         self.action:updateGradInput(self._initInput, gradAction_)
      else
         -- the action of this step was computed from the previous step's
         -- output, so its gradient seeds gradHidden[step-1]
         local gradAction = self.action:updateGradInput(self.output[step-1], gradAction_)
         self.gradHidden[step-1] = nn.rnn.recursiveCopy(self.gradHidden[step-1], gradAction)
      end

      -- 2. backward through the rnn layer
      local gradInput = self.rnn:updateGradInput(input, self.gradHidden[step])[1]
      if step == self.nStep then
         self.gradInput:resizeAs(gradInput):copy(gradInput)
      else
         self.gradInput:add(gradInput)
      end
   end

   return self.gradInput
end
--- Accumulates the parameter gradients of the action module and of the rnn.
-- Reads self.gradHidden (and self._gradAction when actions are not forwarded),
-- both of which are filled by updateGradInput.
-- @param input the input given to updateOutput
-- @param gradOutput table of nStep elements (see updateGradInput)
-- @param scale factor passed on to the wrapped modules
function RecurrentAttention:accGradParameters(input, gradOutput, scale)
   assert(self.rnn.step - 1 == self.nStep, "inconsistent rnn steps")
   assert(torch.type(gradOutput) == 'table', "expecting gradOutput table")
   assert(#gradOutput == self.nStep, "gradOutput should have nStep elements")

   -- walk the steps in reverse (BPTT)
   for t = self.nStep, 1, -1 do
      -- gradient given to the action module: the forwarded action gradient
      -- when there is one, otherwise the buffer kept in self._gradAction
      local actionGrad = self._gradAction
      if self.forwardActions then
         local forwarded = gradOutput[t][2]
         if forwarded then
            actionGrad = forwarded
         end
      end

      -- input the action module saw at this step: the initial zero tensor
      -- on the first step, the previous output afterwards
      local actionInput
      if t == 1 then
         actionInput = self._initInput
      else
         actionInput = self.output[t-1]
      end

      -- action layer first, then the rnn layer
      self.action:accGradParameters(actionInput, actionGrad, scale)
      self.rnn:accGradParameters(input, self.gradHidden[t], scale)
   end
end
--- Same traversal as accGradParameters, but calls accUpdateGradParameters on
-- the action module and on the rnn with the learning rate `lr`.
-- Reads self.gradHidden (and self._gradAction when actions are not forwarded),
-- both of which are filled by updateGradInput.
-- @param input the input given to updateOutput
-- @param gradOutput table of nStep elements (see updateGradInput)
-- @param lr learning rate passed on to the wrapped modules
function RecurrentAttention:accUpdateGradParameters(input, gradOutput, lr)
   assert(self.rnn.step - 1 == self.nStep, "inconsistent rnn steps")
   assert(torch.type(gradOutput) == 'table', "expecting gradOutput table")
   assert(#gradOutput == self.nStep, "gradOutput should have nStep elements")

   -- walk the steps in reverse (BPTT)
   for t = self.nStep, 1, -1 do
      -- gradient given to the action module: the forwarded action gradient
      -- when there is one, otherwise the self._gradAction buffer
      -- (REINFORCE modules ignore this gradient, so the buffer is a dummy)
      local actionGrad = self._gradAction
      if self.forwardActions then
         local forwarded = gradOutput[t][2]
         if forwarded then
            actionGrad = forwarded
         end
      end

      -- input the action module saw at this step: the initial zero tensor
      -- on the first step, the previous output afterwards
      local actionInput
      if t == 1 then
         actionInput = self._initInput
      else
         actionInput = self.output[t-1]
      end

      -- action layer first, then the rnn layer
      self.action:accUpdateGradParameters(actionInput, actionGrad, lr)
      self.rnn:accUpdateGradParameters(input, self.gradHidden[t], lr)
   end
end
--- Converts the module to the given tensor type.
-- Drops a few cached fields, then delegates to the parent implementation.
-- Any extra arguments (e.g. the tensor cache that nn.Module:type accepts to
-- preserve sharing across modules) are forwarded untouched; previously they
-- were silently discarded by this override.
-- NOTE(review): none of the fields cleared here is assigned anywhere else in
-- this file -- confirm whether they are still needed.
-- @param type name of the tensor type to convert to
-- @param ... optional extra arguments passed on to the parent's type()
-- @return whatever the parent's type() returns
function RecurrentAttention:type(type, ...)
   self._input = nil
   self._actions = nil
   self._crop = nil
   self._pad = nil
   self._byte = nil
   return parent.type(self, type, ...)
end
--- Builds a multi-line description of the module: the class name followed by
-- the string forms of the action module and of the rnn, with every line break
-- of a sub-module description followed by an indented prefix.
-- @return the description string
function RecurrentAttention:__tostring__()
   local newline = '\n'
   local indent = ' '
   local bar = ' | '
   -- what every newline inside a sub-module description is replaced with
   local continuation = newline .. indent .. bar

   -- one labelled entry, starting on its own line
   local function entry(label, module)
      -- gsub also returns a match count; keep only the resulting string
      local text = tostring(module):gsub(newline, continuation)
      return newline .. indent .. label .. text
   end

   local header = torch.type(self) .. ' {'
   local actionPart = entry('action : ', self.action)
   local rnnPart = entry('rnn : ', self.rnn)
   return header .. actionPart .. rnnPart .. newline .. '}'
end