-
Notifications
You must be signed in to change notification settings - Fork 81
/
Copy pathsplitMalwareData.lua
165 lines (135 loc) · 4.69 KB
/
splitMalwareData.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
-- Run this program once when given a new dataset.
-- It saves the test / train split to disk;
-- the train set is later sub-divided into train / validation sets.
-- Return indices for the training and testing sets.
-- We will later sub-divide the training set into train & val sets.
--- Split a labelled dataset into per-class (stratified) train and test index lists.
-- An example is positive (malware) when its label equals 1; every other label
-- value counts as negative (benign). Each class is shuffled independently with
-- torch.randperm; the first floor(classCount * pTrain) shuffled examples of the
-- class go to the training list and the rest of the class goes to the test list.
-- @param labels  container read as labels[i] for i = 1..labels:size(1)
--                (presumably a 1-D torch tensor -- confirm against the caller)
-- @param pTrain  fraction of each class to place in the training list
-- @param pTest   unused: the test list always receives the remainder of each
--                class. Kept only so existing callers keep working.
-- @return trainInds    table of indices, training positives followed by training negatives
-- @return testInds     table of indices, test positives followed by test negatives
-- @return posNegRatio  nPosTrain / (nPosTrain + nNegTrain), i.e. the fraction of
--                      positives in the training list (NaN if that list is empty)
function splitMalwareDataTrainTest(labels,pTrain,pTest)
    local pos = {}
    local neg = {}
    local nPrograms = labels:size(1)--allData.program:size(1)

    -- record the indices of all the pos/neg i.e. malware/benign examples
    for i = 1,nPrograms do
        if labels[i] == 1 then
            pos[#pos + 1] = i
        else
            neg[#neg + 1] = i
        end
    end
    print(#pos,#neg)

    -- torch.randperm refuses n == 0, which used to abort the whole split when
    -- one class had no examples. An empty class needs no permutation at all
    -- (the loops below never index it), so hand back an empty table instead.
    local function shuffledOrder(n)
        if n > 0 then
            return torch.randperm(n)
        end
        return {}
    end

    -- shuffle each class, then take the first pTrain share of the positives and
    -- the first pTrain share of the negatives for training
    local indsPos = shuffledOrder(#pos)
    local indsNeg = shuffledOrder(#neg)

    -- the operands are plain Lua numbers, so math.floor is sufficient here
    local nPosTrain = math.floor(#pos * pTrain)
    local nNegTrain = math.floor(#neg * pTrain)
    local nPosTest = #pos - nPosTrain
    local nNegTest = #neg - nNegTrain

    -- NOTE: the value printed under the 'pos/neg ' label is pos / (pos + neg)
    print('splitting dataset')
    print('nPosTrain',nPosTrain,'nNegTrain',nNegTrain,'pos/neg ',nPosTrain / (nPosTrain+nNegTrain))
    print('nPosTest',nPosTest,'nNegTest',nNegTest,'pos/neg ',nPosTest / (nPosTest+nNegTest))

    local trainInds = {}
    local testInds = {}
    for i = 1,nPosTrain do
        trainInds[#trainInds + 1] = pos[indsPos[i]]
    end
    for i = 1,nNegTrain do
        trainInds[#trainInds + 1] = neg[indsNeg[i]]
    end
    -- everything after the training share of each permutation is test data
    for i = 1,nPosTest do
        testInds[#testInds + 1] = pos[indsPos[nPosTrain + i]]
    end
    for i = 1,nNegTest do
        testInds[#testInds + 1] = neg[indsNeg[nNegTrain + i]]
    end

    -- ratio used to weight the classes during training. Deals with
    -- the unbalanced number of examples for each class
    local posNegRatio = nPosTrain / (nPosTrain + nNegTrain)
    return trainInds,testInds,posNegRatio
end
-- Return indices for the train, val and testing sets.
--- Sub-divide an existing train/test split into train, validation and test index lists.
-- The test list is taken unchanged from metaData.testInds. The indices in
-- metaData.trainInds are grouped by class (label == 1 is positive/malware, any
-- other label is negative/benign), each class is shuffled with torch.randperm,
-- and the first floor(0.8 * classCountInWholeDataset) shuffled examples become
-- training data; whatever is left of metaData.trainInds becomes validation data.
-- @param labels    container read as labels[i] for i = 1..labels:size(1)
--                  (presumably a 1-D torch tensor -- confirm against the caller)
-- @param metaData  table with index lists in fields trainInds and testInds
--                  (presumably a saved result of splitMalwareDataTrainTest -- confirm)
-- @return trainInds    table of indices, training positives then training negatives
-- @return valInds      table of indices, validation positives then validation negatives
-- @return testInds     the same table object as metaData.testInds
-- @return posNegRatio  nPosTrain / (nPosTrain + nNegTrain)
-- Raises an error when metaData.trainInds is too small for the 0.8 share, or
-- when the three lists do not cover every program exactly once.
function splitMalwareDataTrainValTest(labels,metaData)
    -- Share of the WHOLE dataset (per class) used for training. Validation is
    -- whatever remains of metaData.trainInds and the test set is fixed by
    -- metaData.testInds, so no separate validation/test fractions are needed.
    local pTrain = 0.8
    local testInds = metaData.testInds
    local nPrograms = labels:size(1)--allData.program:size(1)
    print('nPrograms ',nPrograms)

    -- count the pos/neg i.e. malware/benign examples in the whole dataset
    local nPos, nNeg = 0, 0
    for i = 1,nPrograms do
        if labels[i] == 1 then
            nPos = nPos + 1
        else
            nNeg = nNeg + 1
        end
    end

    -- record the indices of all the pos/neg examples in the existing training-set
    local posTrainVal = {}
    local negTrainVal = {}
    for i = 1,#metaData.trainInds do
        local idx = metaData.trainInds[i]
        if labels[idx] == 1 then
            posTrainVal[#posTrainVal + 1] = idx
        else
            negTrainVal[#negTrainVal + 1] = idx
        end
    end
    print(nPos,nNeg)
    print(#posTrainVal,#negTrainVal)

    -- torch.randperm refuses n == 0, which used to abort the split when a class
    -- had no examples. An empty class is never indexed below, so an empty table
    -- is a safe stand-in.
    local function shuffledOrder(n)
        if n > 0 then
            return torch.randperm(n)
        end
        return {}
    end
    local indsPos = shuffledOrder(#posTrainVal)
    local indsNeg = shuffledOrder(#negTrainVal)

    -- the operands are plain Lua numbers, so math.floor is sufficient here
    local nPosTrain = math.floor(nPos * pTrain)
    local nNegTrain = math.floor(nNeg * pTrain)

    -- The training share is measured against the whole dataset, so the stored
    -- train/test split must leave at least that many examples of each class.
    -- Fail with a clear message instead of indexing past the permutation.
    if nPosTrain > #posTrainVal or nNegTrain > #negTrainVal then
        error(string.format(
            'metaData.trainInds holds too few examples for a %.2f training share: '
            .. 'need %d pos / %d neg, have %d pos / %d neg',
            pTrain, nPosTrain, nNegTrain, #posTrainVal, #negTrainVal))
    end

    local nPosVal = #posTrainVal - nPosTrain
    local nNegVal = #negTrainVal - nNegTrain
    local nPosTest = nPos - (nPosTrain + nPosVal)
    local nNegTest = nNeg - (nNegTrain + nNegVal)
    print('splitting dataset')
    print('nPosTrain',nPosTrain,'nNegTrain',nNegTrain)
    print('nPosVal',nPosVal,'nNegVal',nNegVal)
    print('nPosTest',nPosTest,'nNegTest',nNegTest)

    local trainInds = {}
    local valInds = {}
    for i = 1,nPosTrain do
        trainInds[#trainInds + 1] = posTrainVal[indsPos[i]]
    end
    for i = 1,nNegTrain do
        trainInds[#trainInds + 1] = negTrainVal[indsNeg[i]]
    end
    -- everything after the training share of each permutation is validation data
    for i = 1,nPosVal do
        valInds[#valInds + 1] = posTrainVal[indsPos[nPosTrain + i]]
    end
    for i = 1,nNegVal do
        valInds[#valInds + 1] = negTrainVal[indsNeg[nNegTrain + i]]
    end

    -- ratio used to weight the classes during training. Deals with
    -- the unbalanced number of examples for each class
    local posNegRatio = nPosTrain / (nPosTrain + nNegTrain)

    -- check that every program lands in exactly one of train / val / test:
    -- count how often each index is used and require every count to be 1
    local sanity = torch.zeros(nPrograms)
    for i = 1,#trainInds do
        sanity[trainInds[i]] = sanity[trainInds[i]] + 1
    end
    for i = 1,#testInds do
        sanity[testInds[i]] = sanity[testInds[i]] + 1
    end
    for i = 1,#valInds do
        sanity[valInds[i]] = sanity[valInds[i]] + 1
    end
    local lowest, highest, total = torch.min(sanity), torch.max(sanity), torch.sum(sanity)
    print('train/val/test check',lowest,highest,total,nPrograms)
    if lowest ~= 1 or highest ~= 1 or total ~= nPrograms then
        -- stop if this happens
        error('overlap between training / validation and testing sets')
    end
    return trainInds,valInds,testInds,posNegRatio
end