stump_ensemble_categorical.py
'''
Models an ensemble of decision stumps over categorical features and supports
certification of robustness against l0-bounded perturbations.
'''
import numpy as np
class StumpEnsembleCategorical:
    '''Models a stump ensemble supporting categorical features.'''
    def __init__(self):
        # maps feature index -> {category value -> per-stump prediction}
        self.predictions = {}
    def train(self, X_train, y_train, categorical_indices, w_train=None, lr=1.0):
        '''Trains one decision stump per categorical feature independently.'''
        self.predictions = {}
        if w_train is None:
            w_train = np.ones(len(X_train))
        for cat_idx in categorical_indices:
            predictions_current = {}
            # category values are assumed to be encoded as integers 0..max
            cat_num = int(np.max(X_train[:, cat_idx])) + 1
            for i in range(cat_num):
                flag = X_train[:, cat_idx] == i
                if not np.any(flag):
                    # category value absent from the training data: stay neutral
                    predictions_current[i] = 0.5
                    continue
                # weighted fraction of positive labels among samples with this value
                pos_rate = np.sum(y_train[flag] * w_train[flag]) / np.sum(w_train[flag])
                # majority vote in {-1, +1}, scaled by the learning rate and shifted
                # to [0, 1]: 0.5 + lr/2 if the weighted majority is positive,
                # 0.5 - lr/2 otherwise
                vote = 1 if pos_rate >= 0.5 else -1
                predictions_current[i] = vote * (lr / 2.0) + 0.5
            self.predictions[cat_idx] = predictions_current
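    # A hypothetical example of the learned structure (illustrative values only):
    # with lr=1.0 each stump prediction is 0.0 or 1.0, so after training a
    # feature at index 2 with three category values one might see
    #     self.predictions == {2: {0: 1.0, 1: 0.0, 2: 1.0}}
    # i.e. category values 0 and 2 vote positive while value 1 votes negative.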
    def predict_soft_(self, X):
        '''Soft prediction: returns the summed stump scores and the stump count.'''
        pred = 0.0
        for cat in self.predictions:
            value = int(X[cat])
            if value in self.predictions[cat]:
                pred += self.predictions[cat][value]
            else:
                pred += 0.5  # category value not seen during training: neutral score
        return pred, len(self.predictions)
    def get_perturbations_(self, X, y):
        '''Returns the worst-case per-feature score changes, most damaging first.'''
        perturbations = []
        for cat in self.predictions:
            perturbation = 0.0  # leaving the feature unchanged is always possible
            value = int(X[cat])
            if value in self.predictions[cat]:
                pred = self.predictions[cat][value]
            else:
                pred = 0.5
            # an adversary may switch the feature to any category value seen in training
            for candidate in self.predictions[cat].values():
                if y == 1:
                    # for a positive sample the worst change lowers the score
                    perturbation = min(perturbation, candidate - pred)
                elif y == 0:
                    # for a negative sample the worst change raises the score
                    perturbation = max(perturbation, candidate - pred)
            perturbations.append(perturbation)
        # sort so that the most damaging changes come first
        if y == 0:
            perturbations = sorted(perturbations, reverse=True)
        else:
            perturbations = sorted(perturbations)
        return perturbations
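    # Why this certifies l0 robustness (a sketch of the argument): the soft
    # prediction is a sum of independent per-feature stump scores, so changing
    # at most `radius` features changes the sum by at most the sum of the
    # `radius` most damaging per-feature changes computed above. Adding those
    # to the clean score therefore yields the worst-case score, since each
    # per-feature change is independently realizable by an adversary.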
    def predict(self, X_test):
        '''Hard prediction: thresholds the averaged stump scores at 0.5.'''
        y_pred = []
        for k in range(len(X_test)):
            pred, normalizer = self.predict_soft_(X_test[k])
            y_pred.append(1 if pred / normalizer > 0.5 else 0)
        return y_pred
    def certify(self, X_test, y_test, radius):
        '''Returns whether samples are certifiably robust at the given l0 radius.'''
        y_cert = []
        worst_case_predictions, normalizer = self.worst_case_predictions(X_test, y_test, radius)
        for k in range(len(X_test)):
            certifiable = 0
            if y_test[k] == 1:
                # a positive sample is robust if the worst-case score stays above 0.5
                if worst_case_predictions[k] / normalizer > 0.5:
                    certifiable = 1
            else:
                # a negative sample is robust if it stays at or below 0.5
                if worst_case_predictions[k] / normalizer <= 0.5:
                    certifiable = 1
            y_cert.append(certifiable)
        return y_cert
    def worst_case_predictions(self, X_test, y_test, radius):
        '''Computes the worst-case soft predictions at the given l0 radius.'''
        y_worst_case_predictions = []
        normalizer = len(self.predictions)
        for k in range(len(X_test)):
            pred, _ = self.predict_soft_(X_test[k])
            worst_case_prediction = pred
            # apply the `radius` most damaging single-feature changes
            worst_perturbations = self.get_perturbations_(X_test[k], y_test[k])
            for r in range(radius):
                worst_case_prediction += worst_perturbations[r]
            y_worst_case_predictions.append(worst_case_prediction)
        return y_worst_case_predictions, normalizer
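
# A minimal usage sketch (an assumption, not part of the original module): the
# synthetic data, seed, and radius below are illustrative only. Each feature
# copies the label 80% of the time, so every stump is weakly informative.
if __name__ == '__main__':
    rng = np.random.default_rng(0)
    n, d = 200, 5
    y = rng.integers(0, 2, size=n)
    noise = rng.integers(0, 3, size=(n, d))
    X = np.where(rng.random((n, d)) < 0.8, y[:, None], noise)
    model = StumpEnsembleCategorical()
    model.train(X, y, categorical_indices=list(range(d)))
    y_pred = np.array(model.predict(X))
    y_cert = np.array(model.certify(X, y, radius=1))
    print('accuracy: %.3f' % np.mean(y_pred == y))
    print('certified robust accuracy at l0 radius 1: %.3f' % np.mean(y_cert))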