------------------------------------------------------------------------
--[[ MaskZeroCriterion ]]--
-- Decorator that zeros the err and gradInput of the decorated criterion
-- for batch rows whose commensurate input is a tensor of zeros
-- (from the Element-Research/rnn package).
------------------------------------------------------------------------
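-- A minimal usage sketch (an illustration, not part of the module; assumes
-- nn.ClassNLLCriterion over batchSize x nClass log-probabilities, so nInputDim=1):
--   local criterion = nn.MaskZeroCriterion(nn.ClassNLLCriterion(), 1)
--   local input = torch.randn(4, 10)
--   input[2]:zero() -- row 2 is zero-padding
--   local target = torch.LongTensor{3, 1, 7, 2}
--   local err = criterion:forward(input, target)        -- row 2 adds no loss
--   local gradInput = criterion:backward(input, target) -- gradInput[2] is all zeros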

local MaskZeroCriterion, parent = torch.class("nn.MaskZeroCriterion", "nn.Criterion")

function MaskZeroCriterion:__init(criterion, nInputDim)
   parent.__init(self)
   assert(torch.isTypeOf(criterion, 'nn.Criterion'), 'Expecting nn.Criterion at arg 1')
   assert(torch.type(nInputDim) == 'number', 'Expecting nInputDim number at arg 2')
   self.criterion = criterion
   -- nInputDim : number of non-batch dims in the input (e.g. 1 for batchSize x inputSize)
   self.nInputDim = nInputDim
end

-- the zero-mask is determined by the first tensor found in the (possibly nested) input
function MaskZeroCriterion:recursiveGetFirst(input)
   if torch.type(input) == 'table' then
      return self:recursiveGetFirst(input[1])
   else
      assert(torch.isTensor(input))
      return input
   end
end

-- dst = src:index(1, mask), applied recursively to (possibly nested) tables of tensors
function MaskZeroCriterion:recursiveMask(dst, src, mask)
   if torch.type(src) == 'table' then
      dst = torch.type(dst) == 'table' and dst or {}
      for k,v in ipairs(src) do
         dst[k] = self:recursiveMask(dst[k], v, mask)
      end
   else
      assert(torch.isTensor(src))
      dst = torch.isTensor(dst) and dst or src.new()
      dst:index(src, 1, mask)
   end
   return dst
end
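-- For example (a sketch) : with src = {x, {y}} and mask = torch.LongTensor{1,3},
-- recursiveMask returns {x:index(1,mask), {y:index(1,mask)}}, i.e. rows 1 and 3
-- of every tensor, with the table structure preserved.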

function MaskZeroCriterion:updateOutput(input, target)
   -- recurrent module input is always the first one
   local rmi = self:recursiveGetFirst(input):contiguous()
   if rmi:dim() == self.nInputDim then
      error("does not support online (i.e. non-batch) mode")
   elseif rmi:dim() - 1 == self.nInputDim then
      rmi = rmi:view(rmi:size(1), -1) -- collapse non-batch dims
   else
      error("nInputDim error: "..rmi:dim()..", "..self.nInputDim)
   end

   -- build mask : a row is considered masked iff its L2 norm is zero
   local vectorDim = rmi:dim()
   self._zeroMask = self._zeroMask or rmi.new()
   self._zeroMask:norm(rmi, 2, vectorDim)
   local zeroMask = self._zeroMask
   if torch.isTypeOf(zeroMask, 'torch.CudaTensor') then
      -- apply() with a lua function only works on host tensors
      self.__zeroMask = self.__zeroMask or torch.FloatTensor()
      self.__zeroMask:resize(self._zeroMask:size()):copy(self._zeroMask)
      zeroMask = self.__zeroMask
   end

   -- self.zeroMask collects the indices of the unmasked (non-zero) rows
   self.zeroMask = self.zeroMask or torch.LongTensor()
   self.zeroMask:resize(self._zeroMask:size(1)):zero()

   local i, j = 0, 0
   zeroMask:apply(function(norm)
      i = i + 1
      if norm ~= 0 then
         j = j + 1
         self.zeroMask[j] = i
      end
   end)

   if j > 0 then
      self.zeroMask:resize(j)
      -- keep only the unmasked rows of input and target
      self.input = self:recursiveMask(self.input, input, self.zeroMask)
      self.target = self:recursiveMask(self.target, target, self.zeroMask)
      -- forward through decorated criterion
      self.output = self.criterion:updateOutput(self.input, self.target)
   else
      -- every row is masked : the loss is zero
      self.output = 0
   end
   return self.output
end

function MaskZeroCriterion:updateGradInput(input, target)
   self.gradInput:resizeAs(input):zero()
   if self.zeroMask:nElement() > 0 then
      self._gradInput = self.criterion:updateGradInput(self.input, self.target)
      -- scatter gradients of the unmasked rows back into place;
      -- masked rows keep a zero gradient
      self.gradInput:indexCopy(1, self.zeroMask, self._gradInput)
   end
   return self.gradInput
end

function MaskZeroCriterion:type(type, ...)
   -- drop cached buffers so they are rebuilt with the new tensor type
   self.zeroMask = nil
   self._zeroMask = nil
   self.__zeroMask = nil
   self.input = nil
   self.target = nil
   self._gradInput = nil
   return parent.type(self, type, ...)
end