-
Notifications
You must be signed in to change notification settings - Fork 81
/
Copy pathtrainModel.lua
190 lines (156 loc) · 5.9 KB
/
trainModel.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
-- Copy one program out of the flat token tensor allData.program into a
-- 1 x len tensor of the requested type.
--   allData    : table with fields program, programStartPtrs, programLengths
--   programInd : index of the program inside allData
--   dtype      : torch type string used for the returned tensor
-- Programs longer than opt.maxSequenceLength are cropped to that length. The
-- crop starts at the beginning of the program unless opt.dataAugTesting is
-- set, in which case a uniformly random window inside the program is used.
local function buildProgramTensor(allData, programInd, dtype)
  local startPtr = allData.programStartPtrs[programInd]
  local programLen = allData.programLengths[programInd]
  local maxLen = opt.maxSequenceLength

  if programLen > maxLen then
    local prog = torch.zeros(1, maxLen):type(dtype)
    local offset = 0
    if opt.dataAugTesting then
      -- valid window offsets are 0 .. programLen - maxLen (inclusive); the
      -- previous "- 1" multiplier could never pick the last two windows
      local lastOffset = programLen - maxLen
      offset = math.floor(torch.rand(1)[1] * (lastOffset + 1))
      -- guard against a random value that rounds up to exactly 1
      offset = math.min(offset, lastOffset)
    end
    prog[{{1},{}}] = allData.program[{{startPtr + offset, startPtr + offset + maxLen - 1}}]
    return prog
  end

  local prog = torch.zeros(1, programLen):type(dtype)
  prog[{{1},{}}] = allData.program[{{startPtr, startPtr + programLen - 1}}]
  return prog
end

-- Train `model` with SGD or RMSProp and periodically evaluate it.
--   model, criterion : network and loss, used through forward / backward
--   allData          : table with program, programStartPtrs, programLengths, label
--   trainInds        : Lua array of indices into allData used for training
--   valInds          : Lua array of indices into allData used for validation
--   dataSplit        : table; posNegRatio is read here, the table is stored with the saved model
--   metaData         : stored unchanged with the saved model
-- Reads its settings from the global `opt` table (useCUDA, gpuid, learningRate,
-- weightDecay, batchSize, nEpochs, maxSequenceLength, dataAugTesting,
-- weightClasses, useRMSProp, decayLearningRate, weightDecayFrac,
-- nSamplingEpochs, saveModel, saveFileName).
-- Returns the model on success, or 0 when the accumulated epoch error becomes
-- NaN or exceeds 1e9.
-- NOTE(review): each optimisation step feeds only the first program of the
-- batch (order[i]) to the model while batchLabel holds opt.batchSize labels,
-- so this looks like it expects opt.batchSize == 1 - confirm with the caller.
function trainModel(model,criterion,allData,trainInds,valInds,dataSplit,metaData)
  local parameters,gradParameters = model:getParameters()
  print('Number of Model Parameters ',parameters:size(1))

  local dtype = 'torch.DoubleTensor'
  if opt.useCUDA then
    print('Using CUDA')
    dtype = 'torch.CudaTensor'
  else
    print('Running on CPU - CUDA disabled')
  end

  -- optimiser configuration; optim also keeps its running state in this table
  local config = {
    learningRate = opt.learningRate,
    weightDecay = opt.weightDecay,
  }

  local bestfscore = 0
  -- [1] accuracy, [2] precision, [3] recall, [4] fscore,
  -- [5] mean training error of that epoch, [6] validation error
  local bestResult = torch.zeros(6)
  local timer = torch.Timer()
  local nPrograms = #trainInds
  print('Number of training examples ',#trainInds)
  print('Number of validation examples ',#valInds)

  -- pre-allocate memory for the batch labels
  print('allocating batch memory')
  local batchLabel = torch.zeros(opt.batchSize):type(dtype)
  print('memory allocated')

  if opt.useCUDA then
    local freeMemory, totalMemory = cutorch.getMemoryUsage(opt.gpuid)
    print('CUDA memory usage')
    print('free ',freeMemory,'total ',totalMemory,'ratio ',freeMemory/totalMemory)
  end

  -- per-class multipliers applied to the criterion gradient when
  -- opt.weightClasses is set
  -- NOTE(review): both branches give class 1 the larger weight
  -- (max(r, 1 - r)) - confirm this is the intended handling of posNegRatio
  local gradMultiplier = torch.zeros(2):type(dtype)
  if dataSplit.posNegRatio < 0.5 then
    gradMultiplier[1] = 1 - dataSplit.posNegRatio
    gradMultiplier[2] = dataSplit.posNegRatio
  else
    gradMultiplier[1] = dataSplit.posNegRatio
    gradMultiplier[2] = 1 - dataSplit.posNegRatio
  end

  if nPrograms < opt.batchSize then
    -- the batch loop below only runs over full batches, so nothing would be trained
    print('warning - fewer training examples than opt.batchSize, no batches will be trained')
  end

  -- per-epoch accumulators, reset at the start of every epoch
  local epochError, nBatches, nSamples = 0, 0, 0
  -- input of the current optimisation step, refreshed before each optim call
  local batchProg

  -- closure handed to optim: loss and gradient for the current batchProg / batchLabel
  local function feval(x)
    if x ~= parameters then
      parameters:copy(x)
    end
    gradParameters:zero()
    local output = model:forward(batchProg)
    local netError = criterion:forward(output,batchLabel)
    epochError = epochError + netError
    local gradCriterion = criterion:backward(output,batchLabel)
    if opt.weightClasses then
      -- seems to be a bug in Torch with ClassNLLCriterion as it should
      -- do this automatically ...
      -- manually weight the classes to deal with imbalanced pos / neg samples
      gradCriterion = gradCriterion:cmul(gradMultiplier)
    end
    model:backward(batchProg,gradCriterion)
    return netError,gradParameters
  end

  for e = 1,opt.nEpochs do
    batchLabel:zero()
    epochError, nBatches, nSamples = 0, 0, 0
    local order = torch.randperm(nPrograms)

    -- only full batches are visited; the remainder of the permutation is skipped
    for i = 1,(nPrograms - (nPrograms%opt.batchSize)),opt.batchSize do
      nSamples = nSamples + opt.batchSize
      nBatches = nBatches + 1

      -- build the batch here
      for k = 0,(opt.batchSize-1) do
        batchLabel[{k+1}] = allData.label[trainInds[order[i + k]]]
      end
      batchProg = buildProgramTensor(allData, trainInds[order[i]], dtype)

      if opt.useRMSProp then
        optim.rmsprop(feval, parameters, config)
      else
        optim.sgd(feval, parameters, config)
      end

      if isnan(epochError) then
        print('training fail - Nan')
        return 0
      end
      if epochError > 1e9 then
        print('training fail - gradient exploded')
        return 0
      end
    end

    -- learning rate is scaled by opt.weightDecayFrac at these two fixed epochs
    if (e == 50 or e == 75) and opt.decayLearningRate then
      config.learningRate = config.learningRate * opt.weightDecayFrac
    end

    -- check the cross validation error
    if e % opt.nSamplingEpochs == 0 or e == opt.nEpochs then
      local time = timer:time().real
      print('training time',string.format("%7.3f",time),' nPrograms in training ',nSamples)
      timer:reset()

      local nValPrograms = #valInds
      local nTrainPrograms = #trainInds
      print('nValPrograms',nValPrograms,'nTrainingPrograms',nTrainPrograms)

      -- NaN (0/0) when no batch was trained this epoch
      local meanTrainError = epochError/nBatches

      local valResult,valConfMat,valTime = testModel(allData,model,valInds,bestfscore)
      if valResult.fscore > bestfscore then
        bestfscore = valResult.fscore
        bestResult[1] = valResult.accuracy
        bestResult[2] = valResult.prec
        bestResult[3] = valResult.recall
        bestResult[4] = valResult.fscore
        bestResult[5] = meanTrainError
        bestResult[6] = valResult.testError

        -- save the best model so far and the data split etc
        if opt.saveModel then
          local experimentData = {
            opt = opt,
            trainedModel = model:double(),
            dataSplit = dataSplit,
            metaData = metaData,
          }
          torch.save('./trainedNets/' .. opt.saveFileName .. '.th7',experimentData)
          -- model:double() converted the model itself, so convert it back and
          -- fetch the flattened parameter tensors again
          model:type(dtype)
          parameters, gradParameters = model:getParameters()
          collectgarbage()
        end
      end

      print(e,'val ',meanTrainError,valResult.testError,valResult.accuracy,valResult.prec,valResult.recall,valResult.fscore)
      print('testing time - val ',string.format("%7.3f",valTime),' nValPrograms',nValPrograms)
      print(valConfMat)

      local testResult,testConfMat,testTime = testModel(allData,model,trainInds,1)
      print(e,'train ',meanTrainError,testResult.testError,testResult.accuracy,testResult.prec,testResult.recall,testResult.fscore)
      print('testing time - train',string.format("%7.3f",testTime),' nTrainingPrograms',nTrainPrograms)
      print(testConfMat)
      print('--')

      collectgarbage()
    end
  end

  print('Best Result ',bestResult[5],bestResult[6],bestResult[1],bestResult[2],bestResult[3],bestResult[4])
  return model
end