# evaluate_specific.py
import os
import pickle

import tensorflow as tf
from sklearn.metrics import confusion_matrix

from models.resnet import ResNet
from evaluation.metrics import F1Metric
from datasets.ptbxldataset import PTBXLDataset
from datasets.arr10000dataset import Arr10000Dataset
# from datasets.datagenerator import DataGenerator
from datasets.mitbihardataset import MITBIHARDataset
from datasets.savvydataset import SavvyDataset
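
# Trains one ResNet beat classifier per patient under the patient-specific
# ("specific") evaluation paradigm and pickles a confusion matrix for the
# validation and test splits of each patient.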
# classes = "rhythm"
classes = ["N", "S", "V", "F", "Q"]  # AAMI beat classes
num_classes = 5
db_name = "savvydb"
choice = "static"
eval_p = "specific"  # evaluation paradigm: patient-specific
dat = SavvyDataset(db_name)
# Warning: for now, balance=True and balance=False are treated the same and
# saved to the same files (i.e. overwritten).
dat.generate_train_set(eval_p, choice, True)
dat.generate_val_set(eval_p, choice, False)
dat.generate_test_set(eval_p, choice, False)
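
# One model is trained per patient in dat.specific_patients; results are
# written under exp_path (the F1Metric callback is handed a models/ directory,
# presumably for checkpointing, and confusion matrices are pickled alongside).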
for patient_id in dat.specific_patients:
    # e.g. experiments-dp/savvydb/static/specificpatient/201
    exp_path = "experiments-dp" + os.sep + db_name + os.sep + choice + os.sep + eval_p + "patient" + os.sep + patient_id
    os.makedirs(exp_path + os.sep + "models", exist_ok=True)
    os.makedirs(exp_path + os.sep + "results", exist_ok=True)

    train = dat.load_dataset(eval_p, choice, 'train', patient_id)
    val = dat.load_dataset(eval_p, choice, 'val', patient_id)
    test = dat.load_dataset(eval_p, choice, 'test', patient_id)

    # Each split should come back as an (X, y) pair. Options for now:
    # 1) defined splits for patient-specific, 2) totally random splits for
    # intra-patient, 3) random val split for inter-patient.
    if len(train) != 2:
        # we would need random n-fold cross-validation splits for val and test
        print("generate crossval splits")
    if len(val) != 2:
        # we would need a random split for val
        print("this dataset does not exist")
    if len(test) != 2:
        print("prob. error")  # shouldn't ever happen
    # Legacy PTB-XL generator pipeline, kept for reference:
    '''
    dataset = PTBXLDataset(classes)
    dataset.examine_database()
    print("DONE")
    crossval_split_id = 0
    # Datasets
    partition = dataset.get_crossval_split(crossval_split_id)  # IDs (dict)
    labels = dataset.get_labels()  # Labels (dict)
    # Generators
    training_generator = DataGenerator(partition['train'], labels, **params)
    validation_generator = DataGenerator(partition['validation'], labels, **params)
    '''
    # One-hot encode the integer beat labels.
    train = (train[0], tf.keras.utils.to_categorical(train[1], num_classes=len(classes)))
    val = (val[0], tf.keras.utils.to_categorical(val[1], num_classes=len(classes)))
    test = (test[0], tf.keras.utils.to_categorical(test[1], num_classes=len(classes)))

    model = ResNet(num_outputs=num_classes, blocks=[1, 1], filters=[32, 64], kernel_size=[15, 15], dropout=0.1)
    inputs = tf.keras.layers.Input((200, 1), dtype='float32')
    m1 = tf.keras.Model(inputs=inputs, outputs=model.call(inputs))
    # m1.summary()
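    # Wrapping the subclassed ResNet in a functional tf.keras.Model bakes in a
    # fixed (200, 1) input; this suggests 200-sample, single-lead beat segments
    # (an inference from the input shape, not stated elsewhere in the script).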
    opt = tf.keras.optimizers.Adam(learning_rate=0.0001)  # 'lr' is deprecated
    m1.compile(optimizer=opt,
               # tf.keras.optimizers.Adam(beta_1=0.9, beta_2=0.98, epsilon=1e-9),
               # loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
               loss='categorical_crossentropy',
               metrics=['acc'])
    es = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=4)
    log_f1 = F1Metric(train=train, validation=val, path=exp_path + os.sep + "models")
    m1.fit(x=train[0], y=train[1], validation_data=val, callbacks=[es, log_f1], epochs=100)
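    # Note: EarlyStopping defaults to restore_best_weights=False, so the
    # evaluation below uses the weights from the final epoch, not the best one.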
    # Confusion matrix on the test split.
    y_pred = m1.predict(test[0])
    cm = confusion_matrix(test[1].argmax(axis=1), y_pred.argmax(axis=1), labels=list(range(num_classes)))
    with open(exp_path + os.sep + 'CM_test.pkl', 'wb') as output:
        pickle.dump(cm, output)

    # Confusion matrix on the validation split.
    y_pred = m1.predict(val[0])
    cm = confusion_matrix(val[1].argmax(axis=1), y_pred.argmax(axis=1), labels=list(range(num_classes)))
    with open(exp_path + os.sep + 'CM_val.pkl', 'wb') as output:
        pickle.dump(cm, output)
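    # A minimal sketch (not part of the original script) of how the pickled
    # matrices can be consumed later; per-class recall is the diagonal over
    # the row sums:
    #
    #   with open(exp_path + os.sep + 'CM_test.pkl', 'rb') as f:
    #       cm = pickle.load(f)
    #   recall = cm.diagonal() / cm.sum(axis=1)  # one value per class (N/S/V/F/Q)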
# Legacy checkpoint/generator training code, kept for reference:
'''
# initialize the weights of the model
input_shape, _ = tf.compat.v1.data.get_output_shapes(train_data)
inputs = build_input_tensor_from_shape(input_shape, dtype=input_dtype, ignore_batch_dim=True)
model(inputs)

checkpoint = CustomCheckpoint(
    filepath=str(args.job_dir / 'epoch_{epoch:02d}' / 'model.weights'),
    data=(validation_data, val['y']),
    score_fn=f1,
    save_best_only=False,
    verbose=1)
logger = tf.keras.callbacks.CSVLogger(str(args.job_dir / 'history.csv'))

model.fit_generator(generator=training_generator,
                    validation_data=validation_generator,
                    use_multiprocessing=True,
                    workers=6)
'''