forked from Element-Research/rnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSequencerCriterion.lua
59 lines (55 loc) · 2.34 KB
/
SequencerCriterion.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
------------------------------------------------------------------------
--[[ SequencerCriterion ]]--
-- Applies a criterion to each of the inputs and targets in the
-- corresponding input and target Tables.
-- Useful for nn.Repeater and nn.Sequencer.
-- WARNING : assumes that the decorated criterion is stateless, i.e.
-- the backward doesn't need to be preceded by a commensurate forward.
------------------------------------------------------------------------
-- Declare the class 'nn.SequencerCriterion' with 'nn.Criterion' as its parent.
-- `parent` is kept so that the constructor can call the base-class __init.
local SequencerCriterion, parent = torch.class('nn.SequencerCriterion', 'nn.Criterion')
-- Constructor.
-- criterion : the criterion that will be applied to each (input, target)
--             pair of the input and target tables.
-- Raises an error when `criterion` is an nn.ModuleCriterion, since the
-- decoration has to be done the other way around in that case.
function SequencerCriterion:__init(criterion)
   parent.__init(self)
   self.criterion = criterion

   local decoratesModuleCriterion = torch.isTypeOf(criterion, 'nn.ModuleCriterion')
   if decoratesModuleCriterion then
      local message = "SequencerCriterion shouldn't decorate a ModuleCriterion. "
         .. "Instead, try the other way around : "
         .. "ModuleCriterion decorates a SequencerCriterion. "
         .. "Its modules can also be similarly decorated with a Sequencer."
      error(message)
   end

   -- pool of spare gradInput buffers kept around for reuse
   self._gradInput = {}
   -- one gradInput per element of the input table
   self.gradInput = {}
end
-- Forwards every (input, target) pair through the decorated criterion and
-- accumulates the results.
-- inputTable  : array of inputs (iterated with ipairs, so it stops at the
--               first nil entry)
-- targetTable : array of targets, indexed like inputTable
-- Returns the accumulated criterion output, also stored in self.output.
function SequencerCriterion:updateOutput(inputTable, targetTable)
   local criterion = self.criterion
   self.output = 0
   for step, stepInput in ipairs(inputTable) do
      local stepLoss = criterion:forward(stepInput, targetTable[step])
      self.output = self.output + stepLoss
   end
   return self.output
end
-- Computes the gradient of the decorated criterion for every
-- (input, target) pair.
-- inputTable  : array of inputs (iterated with ipairs)
-- targetTable : array of targets, indexed like inputTable
-- Returns self.gradInput, a table holding one gradInput per input.
-- Raises an error (once, on the first call with at least 3 inputs) when the
-- decorated criterion turns out to be stateful, i.e. when its backward
-- gives a different result depending on which forward preceded it.
function SequencerCriterion:updateGradInput(inputTable, targetTable)
   for i, input in ipairs(inputTable) do
      -- backward is deliberately called without a matching forward;
      -- the result is copied into the buffer already stored at slot i or,
      -- failing that, into a spare buffer taken from the self._gradInput pool
      -- (NOTE(review): recursiveCopy presumably allocates when given nil)
      self.gradInput[i] = nn.rnn.recursiveCopy(
         self.gradInput[i] or table.remove(self._gradInput, 1),
         self.criterion:backward(input, targetTable[i])
      )
   end

   -- remove extra gradInput tensors (save for later)
   for i = #inputTable + 1, #self.gradInput do
      table.insert(self._gradInput, self.gradInput[i])
      self.gradInput[i] = nil
   end

   if #inputTable >= 3 and not self.isStateless then
      -- make sure the criterion is stateless : a backward preceded by its
      -- own forward must match the backward-only result computed above
      local gradInput
      for i = 1, 3 do
         self.criterion:forward(inputTable[i], targetTable[i])
         gradInput = self.criterion:backward(inputTable[i], targetTable[i])
         -- in place : gradInput <- gradInput - self.gradInput[i]
         -- (this overwrites the tensor returned by the criterion's backward;
         -- self.gradInput[i] is a separate copy and is left untouched)
         nn.utils.recursiveAdd(gradInput, -1, self.gradInput[i])
         -- a non-zero residual means the preceding forward changed the
         -- gradient, so the criterion is stateful.
         -- NOTE(review): a plain sum can cancel out opposite-signed
         -- differences; kept as is to rely only on the existing helpers.
         if math.abs(nn.rnn.recursiveSum(gradInput)) > 0.0001 then
            error("SequencerCriterion only decorates stateless criterions : "..tostring(self.criterion))
         end
      end
      self.isStateless = true -- test should only be run once
   end

   return self.gradInput
end