-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
93 lines (75 loc) · 2.85 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
import argparse
import json
import keras
import numpy as np
import os
import random
import time
import network
import load
import util
# Upper bound on training epochs, passed to Keras as `epochs`; the
# EarlyStopping callback configured in train() may end training earlier.
MAX_EPOCHS = 100
def make_save_dir(dirname, experiment_name):
    """Create a fresh run directory and return its path.

    The path is ``dirname/experiment_name/<unix seconds>-<0..999>``. The
    random suffix lowers the chance that two runs started in the same
    second pick the same directory. Missing parent directories are created.
    """
    run_stamp = "{}-{}".format(int(time.time()), random.randrange(1000))
    run_dir = os.path.join(dirname, experiment_name, run_stamp)
    if not os.path.exists(run_dir):
        os.makedirs(run_dir)
    return run_dir
def get_filename_for_saving(save_dir):
    """Return the checkpoint filepath template located inside ``save_dir``.

    The returned string is a ``str.format`` template (not a concrete path).
    train() hands it to ``keras.callbacks.ModelCheckpoint`` as ``filepath``,
    so each checkpoint name records val_loss, val_acc, epoch, loss and acc.

    NOTE(review): the placeholder keys must match the metric names Keras
    writes into its epoch logs; newer Keras releases report
    'accuracy'/'val_accuracy' rather than 'acc'/'val_acc' -- confirm
    against the installed version.
    """
    checkpoint_name = ("{val_loss:.3f}-{val_acc:.3f}-{epoch:03d}"
                       "-{loss:.3f}-{acc:.3f}.hdf5")
    return os.path.join(save_dir, checkpoint_name)
def _num_steps(num_examples, batch_size, split_name):
    """Return the number of whole batches of ``batch_size`` in ``num_examples``.

    Keeps the original truncating behaviour (a trailing partial batch is not
    counted) but raises instead of returning 0, because a zero step count
    would give Keras nothing to train or validate on.
    """
    steps = int(num_examples // batch_size)
    if steps < 1:
        raise ValueError(
            "{} set has {} examples, fewer than batch_size={}; lower "
            "batch_size or disable the generator.".format(
                split_name, num_examples, batch_size))
    return steps


def _make_callbacks(params, save_dir):
    """Build the [checkpoint, reduce-LR, early-stopping] callback list for fit."""
    stopping = keras.callbacks.EarlyStopping(patience=8)
    # Decay LR 10x on plateau, but never below 1/1000 of the initial rate.
    reduce_lr = keras.callbacks.ReduceLROnPlateau(
        factor=0.1,
        patience=2,
        min_lr=params["learning_rate"] * 0.001)
    # save_best_only=False: a checkpoint is written for every epoch.
    checkpointer = keras.callbacks.ModelCheckpoint(
        filepath=get_filename_for_saving(save_dir),
        save_best_only=False)
    return [checkpointer, reduce_lr, stopping]


def train(args, params):
    """Train the network described by ``params``, checkpointing each epoch.

    Args:
        args: parsed command-line namespace; only ``args.experiment`` is
            read, and used as the experiment sub-directory name.
        params: config dict. Keys read here: 'train' and 'dev' (passed to
            load.load_dataset), 'save_dir', 'learning_rate', and optional
            'batch_size' (default 32) and 'generator' (default False). The
            whole dict is also forwarded to network.build_network as
            keyword arguments.

    Side effects:
        Mutates ``params`` in place, adding 'input_shape' and
        'num_categories'. Creates a run directory, saves the preprocessor
        into it via util.save, and writes per-epoch model checkpoints there.

    Raises:
        ValueError: in generator mode, if the train or dev split has fewer
            examples than ``batch_size``.
    """
    print("Loading training set...")
    # Presumably an (ecgs, labels) pair; it is only ever unpacked
    # positionally or indexed with [0] below.
    train_set = load.load_dataset(params['train'])
    print("Loading dev set...")
    dev_set = load.load_dataset(params['dev'])
    print("Building preprocessor...")
    # The preprocessor is fitted on the training split only.
    preproc = load.Preproc(*train_set)
    print("Training size: " + str(len(train_set[0])) + " examples.")
    print("Dev size: " + str(len(dev_set[0])) + " examples.")

    save_dir = make_save_dir(params['save_dir'], args.experiment)
    util.save(preproc, save_dir)

    # Values derived from the data are injected into the config before the
    # network is built. [None, 1] presumably means variable-length input
    # with one channel -- confirm against network.build_network.
    params.update({
        "input_shape": [None, 1],
        "num_categories": len(preproc.classes)
    })
    model = network.build_network(**params)
    callbacks = _make_callbacks(params, save_dir)

    batch_size = params.get("batch_size", 32)
    if params.get("generator", False):
        train_steps = _num_steps(len(train_set[0]), batch_size, "Training")
        dev_steps = _num_steps(len(dev_set[0]), batch_size, "Dev")
        train_gen = load.data_generator(batch_size, preproc, *train_set)
        dev_gen = load.data_generator(batch_size, preproc, *dev_set)
        model.fit_generator(
            train_gen,
            steps_per_epoch=train_steps,
            epochs=MAX_EPOCHS,
            validation_data=dev_gen,
            validation_steps=dev_steps,
            callbacks=callbacks)
    else:
        train_x, train_y = preproc.process(*train_set)
        dev_x, dev_y = preproc.process(*dev_set)
        model.fit(
            train_x, train_y,
            batch_size=batch_size,
            epochs=MAX_EPOCHS,
            validation_data=(dev_x, dev_y),
            callbacks=callbacks)
if __name__ == '__main__':
    # Command-line entry point: load a JSON config file and run train().
    parser = argparse.ArgumentParser()
    parser.add_argument("config_file", help="path to config file")
    parser.add_argument("--experiment", "-e", help="tag with experiment name",
                        default="default")
    args = parser.parse_args()
    # Use a context manager so the config file handle is closed promptly
    # instead of being left for the garbage collector.
    with open(args.config_file, 'r') as config_fh:
        params = json.load(config_fh)
    train(args, params)