-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdemo_predict.py
41 lines (26 loc) · 1.08 KB
/
demo_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import sys
import torch
from parser import get_argument_parser
from config import Configuration
import preprocess as pp
from predict import generate_seed_lyrics, predict
from postprocess import postprocess
def main(args):
    """Run lyric prediction with a pre-trained model on the sample lyrics dataset.

    Example script: reads the configured artist's lyrics files and the saved
    vocabulary dictionary, builds seed lyrics from the tokenized dataset,
    generates new lyrics with the saved model, post-processes them and prints
    the result.

    Args:
        args: Parsed command-line namespace. Reads ``args.artist`` (when truthy
            it overrides the configured artist), ``args.censored`` (forwarded
            to seed generation and post-processing) and ``args.words``
            (forwarded to ``predict`` as ``num_words``).
    """
    c = Configuration()
    if args.artist:
        c.set_artist(args.artist)
    print("Artist:", c.artist.replace("_", " ").title())

    lyrics_dataset = pp.read_lyrics_files(c.path)

    # Open the file in a context manager so its handle is closed as soon as
    # torch.load has finished reading it.
    # NOTE(review): torch.load unpickles arbitrary Python objects; only load
    # dictionary/model files from a trusted source.
    with open(c.dictionary_path, 'rb') as dictionary_file:
        dictionary = torch.load(dictionary_file)
    print("Vocabulary size: ", len(dictionary))
    print("----------------------------")

    tokenized = pp.tokenize(lyrics_dataset)
    seed_lyrics = generate_seed_lyrics(tokenized, c.window_size, args.censored)

    # NOTE(review): recent torch releases default torch.load to
    # weights_only=True, which may reject a fully pickled model object;
    # confirm against the installed torch version.
    with open(c.model_path, 'rb') as model_file:
        model = torch.load(model_file)
    predicted_lyrics = predict(model, seed_lyrics, dictionary, num_words=args.words, topk=c.predict_topk)
    predicted_lyrics = postprocess(predicted_lyrics, args.censored)
    print(predicted_lyrics)
if __name__ == "__main__":
    # Script entry point: parse everything after the program name and run.
    arg_parser = get_argument_parser()
    main(arg_parser.parse_args(sys.argv[1:]))