-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtrain.py
43 lines (33 loc) · 1.15 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import network
import data
import preprocessing
import numpy as np
from keras.callbacks import ModelCheckpoint
from tqdm import tqdm
import config
from keras.callbacks import CSVLogger
import os
def init():
    """Train a Neural Network to generate music.

    Loads the training sequences via ``data.prepare_sequences()``, constructs a
    ``network.Network`` from the vocabulary size it reports, makes sure the
    output directories exist, and passes the resulting model to ``train``.
    """
    network_input, network_output, n_vocab, _, _ = data.prepare_sequences()
    net = network.Network(n_vocab)
    model = net.model
    # Always ensure "outputs/weights" exists (this also creates "outputs").
    # Guarding on "outputs" alone would skip creating the weights folder
    # whenever "outputs" already existed without it, leaving the checkpoint
    # callback with no directory to write into. exist_ok keeps reruns safe.
    os.makedirs("outputs/weights", exist_ok=True)
    train(model, network_input, network_output)
def train(model, network_input, network_output):
    """Train the neural network, checkpointing weights and logging metrics.

    Args:
        model: Keras model to fit (presumably already compiled by
            ``network.Network`` — TODO confirm).
        network_input: Training inputs, passed straight to ``model.fit``.
        network_output: Training targets, passed straight to ``model.fit``.

    Side effects:
        Writes weight checkpoints into ``outputs/weights/`` and appends
        per-epoch metrics to ``outputs/train_log.csv`` (``;``-separated).
    """
    # This function writes under outputs/ itself, so it makes sure the
    # directories exist even when it is called directly rather than via init().
    os.makedirs("outputs/weights", exist_ok=True)

    # NOTE(review): checkpoints are *selected* on val_loss, but the file name
    # embeds the training loss. The pattern is kept unchanged because other
    # scripts may look up weights by this naming scheme.
    filepath = "outputs/weights/weights-{loss:.4f}.hdf5"
    checkpoint = ModelCheckpoint(
        filepath,
        monitor='val_loss',
        verbose=0,
        save_best_only=True,  # only write a file when val_loss improves
        mode='min',
    )
    # append=True keeps rows from earlier runs instead of overwriting the log.
    csv_logger = CSVLogger('outputs/train_log.csv', append=True, separator=';')
    callbacks_list = [checkpoint, csv_logger]

    # Hold out 10% of the samples as validation data; the val_loss monitored
    # by the checkpoint above is computed on this split.
    model.fit(
        network_input,
        network_output,
        epochs=config.NUMBER_EPOCHS,
        batch_size=config.BATCH_SIZE,
        callbacks=callbacks_list,
        validation_split=0.10,
    )
if __name__ == "__main__":
    # Kick off training only when this file is executed as a script,
    # not when it is imported as a module.
    init()