-- utils.lua
require 'torch'
stringx = require 'pl.stringx'
tds = require 'tds'
lmdb = require 'lmdb' -- lmdb.torch binding used for the spect/trans/times/uttid databases
-- Read a phone-to-class mapping file with one space-separated "phone class" pair per line.
-- Returns a phone->class table and a list of the distinct class names.
function getPhoneClasses(filename)
    local phone2class, classes = {}, {}
    local seen = {}
    for line in io.lines(filename) do
        local phone, class = unpack(line:split(' '))
        phone2class[phone] = class
        if not seen[class] then
            seen[class] = true
            table.insert(classes, class)
        end
    end
    return phone2class, classes
end
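-- Example usage of getPhoneClasses (illustrative sketch, not part of the original code).
-- The file name and its contents below are hypothetical; the expected format is one
-- space-separated "phone class" pair per line, e.g.:
--   aa vowel
--   ae vowel
--   b stop
-- local phone2class, classes = getPhoneClasses('phones2classes.txt')
-- print(phone2class['aa'])  --> vowel
-- print(#classes)           --> number of distinct classes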
-- Scan the transcription database under dirPath/trans and build label<->index mappings
-- over all labels that occur in the transcripts.
function getLabels(dirPath)
    local dbTrans = lmdb.env { Path = dirPath .. '/trans', Name = 'trans' }
    dbTrans:open()
    local readerTrans = dbTrans:txn(true)
    local dataSize = dbTrans:stat()['entries']
    local label2idx, idx2label = {}, {}
    for i = 1, dataSize do
        local labels = readerTrans:get(i):split(' ')
        for _, label in ipairs(labels) do
            if not label2idx[label] then
                idx2label[#idx2label + 1] = label
                label2idx[label] = #idx2label
            end
        end
    end
    readerTrans:abort()
    dbTrans:close()
    return label2idx, idx2label
end
-- Return the number of frequency bins by inspecting the first spectrogram in dirPath/spect.
function getFreq(dirPath)
    local dbSpect = lmdb.env { Path = dirPath .. '/spect', Name = 'spect' }
    dbSpect:open()
    local readerSpect = dbSpect:txn(true)
    local tensor = readerSpect:get(1):float()
    local freq = tensor:size(1)
    readerSpect:abort()
    dbSpect:close()
    return freq
end
-- Create LMDB environments for the spect/trans/times databases of a data split and
-- return them together with the number of entries.
function getSplitDBs(dirPath)
    local dbSpect = lmdb.env { Path = dirPath .. '/spect', Name = 'spect' }
    local dbTrans = lmdb.env { Path = dirPath .. '/trans', Name = 'trans' }
    local dbTimes = lmdb.env { Path = dirPath .. '/times', Name = 'times' }
    dbSpect:open()
    local dataSize = dbSpect:stat()['entries']
    dbSpect:close()
    return dbSpect, dbTrans, dbTimes, dataSize
end
-- TODO merge with getSplitDBs
-- Like getSplitDBs but with an utterance-id database instead of phone end times.
function getSplitDBsUttids(dirPath)
    local dbSpect = lmdb.env { Path = dirPath .. '/spect', Name = 'spect' }
    local dbTrans = lmdb.env { Path = dirPath .. '/trans', Name = 'trans' }
    local dbUttids = lmdb.env { Path = dirPath .. '/uttid', Name = 'uttid' }
    dbSpect:open()
    local dataSize = dbSpect:stat()['entries']
    dbSpect:close()
    return dbSpect, dbTrans, dbUttids, dataSize
end
-- Load all spectrograms, transcripts and phone end times from the given databases into memory.
function loadData(dbSpect, dbTrans, dbTimes)
    local tensors = tds.Vec()
    local transcripts = {}
    local times = {}
    dbSpect:open(); local readerSpect = dbSpect:txn(true)
    dbTrans:open(); local readerTrans = dbTrans:txn(true)
    dbTimes:open(); local readerTimes = dbTimes:txn(true)
    local size = dbSpect:stat()['entries']
    -- read out all the data and store in lists
    for x = 1, size do
        local tensor = readerSpect:get(x):float()
        local transcript = readerTrans:get(x)
        local curTimes = readerTimes:get(x):long()
        tensors:insert(tensor)
        table.insert(transcripts, transcript)
        table.insert(times, curTimes)
    end
    readerSpect:abort(); dbSpect:close()
    readerTrans:abort(); dbTrans:close()
    readerTimes:abort(); dbTimes:close()
    return tensors, transcripts, times
end
-- TODO merge with loadData
-- Like loadData but reads utterance ids instead of phone end times.
function loadDataUttids(dbSpect, dbTrans, dbUttids)
    local tensors = tds.Vec()
    local transcripts = {}
    local uttids = {}
    dbSpect:open(); local readerSpect = dbSpect:txn(true)
    dbTrans:open(); local readerTrans = dbTrans:txn(true)
    dbUttids:open(); local readerUttids = dbUttids:txn(true)
    local size = dbSpect:stat()['entries']
    -- read out all the data and store in lists
    for x = 1, size do
        local tensor = readerSpect:get(x):float()
        local transcript = readerTrans:get(x)
        local curUttid = readerUttids:get(x)
        tensors:insert(tensor)
        table.insert(transcripts, transcript)
        table.insert(uttids, curUttid)
    end
    readerSpect:abort(); dbSpect:close()
    readerTrans:abort(); dbTrans:close()
    readerUttids:abort(); dbUttids:close()
    return tensors, transcripts, uttids
end
-- Assemble a batch from the utterances selected by indices (a 1D tensor of indices into
-- spects/transcripts/times). Spectrograms are zero-padded along time to the longest
-- utterance in the batch. Returns the padded batch tensor (size x 1 x freq x maxLength),
-- the per-utterance lengths, and the corresponding transcripts and times.
function nextBatch(indices, spects, transcripts, times)
    local batchTensors = tds.Vec()
    local batchTranscripts = {}
    local batchTimes = {}
    local maxLength = 0
    local freq = 0
    local size = indices:size(1)
    local batchSizes = torch.Tensor(size)
    -- read out a batch and store in lists
    for x = 1, size do
        local ind = indices[x]
        local tensor = spects[ind]
        local transcript = transcripts[ind]
        local curTimes = times[ind]
        freq = tensor:size(1)
        batchSizes[x] = tensor:size(2)
        if maxLength < tensor:size(2) then maxLength = tensor:size(2) end -- find the max len in this batch
        batchTensors:insert(tensor)
        table.insert(batchTranscripts, transcript)
        table.insert(batchTimes, curTimes)
    end
    -- copy each spectrogram into a zero-padded batch tensor
    local batchInputs = torch.Tensor(size, 1, freq, maxLength):zero()
    for ind = 1, size do
        local tensor = batchTensors[ind]
        batchInputs[ind][1]:narrow(2, 1, tensor:size(2)):copy(tensor)
    end
    return batchInputs, batchSizes, batchTranscripts, batchTimes
end
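-- Example usage of getSplitDBs/loadData/nextBatch (illustrative sketch, not part of the
-- original code). The directory 'data/train' and the batch size are hypothetical; the
-- layout with spect/trans/times sub-databases matches the functions above.
-- local dbSpect, dbTrans, dbTimes, dataSize = getSplitDBs('data/train')
-- local spects, transcripts, times = loadData(dbSpect, dbTrans, dbTimes)
-- local perm = torch.randperm(dataSize):long()
-- local indices = perm:narrow(1, 1, 16) -- first 16 shuffled utterances
-- local inputs, sizes, batchTranscripts, batchTimes = nextBatch(indices, spects, transcripts, times)
-- -- inputs has size 16 x 1 x freq x maxLength, zero-padded along time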
-- Like nextBatch but carries utterance ids instead of phone end times.
function nextBatchUttids(indices, spects, transcripts, uttids)
    local batchTensors = tds.Vec()
    local batchTranscripts = {}
    local batchUttids = {}
    local maxLength = 0
    local freq = 0
    local size = indices:size(1)
    local batchSizes = torch.Tensor(size)
    -- read out a batch and store in lists
    for x = 1, size do
        local ind = indices[x]
        local tensor = spects[ind]
        local transcript = transcripts[ind]
        local curUttid = uttids[ind]
        freq = tensor:size(1)
        batchSizes[x] = tensor:size(2)
        if maxLength < tensor:size(2) then maxLength = tensor:size(2) end -- find the max len in this batch
        batchTensors:insert(tensor)
        table.insert(batchTranscripts, transcript)
        table.insert(batchUttids, curUttid)
    end
    -- copy each spectrogram into a zero-padded batch tensor
    local batchInputs = torch.Tensor(size, 1, freq, maxLength):zero()
    for ind = 1, size do
        local tensor = batchTensors[ind]
        batchInputs[ind][1]:narrow(2, 1, tensor:size(2)):copy(tensor)
    end
    return batchInputs, batchSizes, batchTranscripts, batchUttids
end
--[[
Find the frame label (a phoneme string) for one classifier input frame.
transcript: a string of space-separated phonemes
times: a tensor of integers giving each phoneme's end time (in samples) in the input audio
t: index of the current frame that is input to the classifier (i.e. an output frame of the deep speech model)
layer: name of the layer whose output is fed to the classifier
convStepSize: step size of the convolution in the time dimension
sampleRate: sampling rate of the original audio
spectStride: stride of the spectrogram (in seconds)
spectWindowSize: window size of the spectrogram (in seconds)
ignoreSilence: whether to ignore begin/end silence
--]]
function getFrameLabel(transcript, times, t, layer, convStepSize, sampleRate, spectStride, spectWindowSize, ignoreSilence)
    sampleRate = sampleRate or 16000
    spectStride = spectStride or 0.01
    spectWindowSize = spectWindowSize or 0.02
    if ignoreSilence == nil then ignoreSilence = true end
    convStepSize = convStepSize or 2
    local deepSpeechInputFrame -- index of the input frame to the deep speech model
    if layer == 'cnn' or layer == 'cnn2' or stringx.startswith(layer, 'rnn') then
        deepSpeechInputFrame = (((t - 1) * convStepSize + 11) - 1) * convStepSize + 11
    elseif layer == 'cnn1' then
        deepSpeechInputFrame = (t - 1) * convStepSize + 11
    elseif layer == 'input' then
        deepSpeechInputFrame = t
    else
        error('Unsupported layer ' .. layer .. ' in getFrameLabel')
    end
    -- locate the centre (in samples) of the spectrogram window that produced this frame
    local windowStart = deepSpeechInputFrame * sampleRate * spectStride
    local windowMiddle = math.floor(windowStart + spectWindowSize * sampleRate / 2)
    -- find the first phoneme whose end time exceeds the window centre
    for i = 1, times:size(1) do
        if windowMiddle < times[i] then
            -- optionally ignore begin/end silence
            if ignoreSilence and (i == 1 or i == times:size(1)) then
                return nil
            else
                return transcript:split(' ')[i]
            end
        end
    end
    -- if the frame falls beyond the last end time, it is ignored
    return nil
end
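-- Example usage of getFrameLabel (illustrative sketch, not part of the original code).
-- The transcript and end times are made up; times are in samples at the assumed 16 kHz rate.
-- local transcript = 'sil b iy t sil'
-- local times = torch.LongTensor({1600, 4800, 8000, 11200, 16000})
-- -- label the 5th output frame of the first convolutional layer
-- local label = getFrameLabel(transcript, times, 5, 'cnn1', 2, 16000, 0.01, 0.02, true)
-- -- returns the phone whose segment contains the centre of the corresponding
-- -- spectrogram window ('b' here), or nil for leading/trailing silence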
-- Build a flattened window of 2*windowSize+1 frames of repr centred on time step t
-- (element k along the second dimension), zero-filling positions outside the sequence.
function getWindowedInput(windowSize, repr, t, k, timeDim)
    local windowedInput = torch.zeros(2 * windowSize + 1, repr[t][k]:nElement())
    -- TODO fix this to use repr to create new zero tensor
    windowedInput = windowedInput:cuda()
    -- TODO vectorize
    for w = 1, 2 * windowSize + 1 do
        local curFrameId = t - windowSize - 1 + w
        if curFrameId >= 1 and curFrameId <= repr:size(timeDim) then
            windowedInput[w] = repr[curFrameId][k]
        end
    end
    return windowedInput:view(windowedInput:nElement())
end
-- Write a 2D tensor to a text file, one row per line, with values separated by sep
-- (default: a single space).
function writeMatrixToFile(mat, file, sep)
    assert(mat:dim() == 2, 'wrong matrix dimension: ' .. mat:dim())
    sep = sep or ' '
    local f = assert(io.open(file, 'w'))
    for i = 1, mat:size(1) do
        for j = 1, mat:size(2) do
            f:write(mat[i][j])
            if j == mat:size(2) then
                f:write('\n')
            else
                f:write(sep)
            end
        end
    end
    f:close()
end