-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdemo_train.py
58 lines (36 loc) · 1.55 KB
/
demo_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import sys
import torch
from torch.utils.data import DataLoader
from parser import get_argument_parser
from config import Configuration
import preprocess as pp
from models import LyricPredictor
from train import train
from predict import generate_seed_lyrics, predict
from postprocess import postprocess
def main(args):
    """Train a lyric model on the sample lyrics dataset, save it, and print generated lyrics.

    Args:
        args: Parsed command-line arguments. This function reads
            ``args.artist``, ``args.censored`` and ``args.words``.
    """
    config = Configuration()
    if args.artist:
        # Optional override: only applied when an artist was given.
        config.set_artist(args.artist)
    print("Hyperparameters: ", config)

    # Load and tokenize the raw lyrics, then turn them into model inputs/targets.
    print("Loading data from path: ", config.path)
    raw_lyrics = pp.read_lyrics_files(config.path)
    tokenized = pp.tokenize(raw_lyrics)
    inputs, targets, dictionary = pp.preprocess(tokenized, config.window_size)

    # Pair each input with its target so the loader yields (x, y) batches.
    samples = list(zip(inputs, targets))
    training_data = DataLoader(
        samples,
        batch_size=config.train_batch_size,
        shuffle=True,
    )

    model = LyricPredictor(len(dictionary), config.output_size)

    print("Training model...")
    # train() returns a 3-tuple; only the first element (the model) is used here.
    model, _, _ = train(
        model=model,
        training_data=training_data,
        num_epochs=config.num_epochs,
        lr=config.lr,
        grad_norm=config.grad_max_norm,
    )

    # Persist the model object itself (not a state_dict) and the dictionary.
    print("Saving model: ", config.model_path)
    torch.save(model, config.model_path)
    print("Saving dictionary: ", config.dictionary_path)
    torch.save(dictionary, config.dictionary_path)

    # Produce a sample from the freshly trained model and print it.
    print("Generating lyrics...")
    seed_lyrics = generate_seed_lyrics(tokenized, config.window_size, args.censored)
    generated = predict(
        model,
        seed_lyrics,
        dictionary,
        num_words=args.words,
        topk=config.predict_topk,
    )
    generated = postprocess(generated, args.censored)
    print(generated)
if __name__ == "__main__":
    # Script entry point: parse the command line (without the program name)
    # and hand the resulting arguments to main().
    cli_args = get_argument_parser().parse_args(sys.argv[1:])
    main(cli_args)