-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsongGen.py
69 lines (55 loc) · 2.01 KB
/
songGen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
from glob import glob
from tqdm import tqdm
from midi_utils import *
from bard import RubberBardLSTMFC
from bard import RubberBardLSTMFC2
from bard import RubberBardFCFCFC
import torch.nn.functional as F
from numpy import random
# Locate saved checkpoint files. `glob` is imported as the function itself
# (`from glob import glob`), so it is called directly. Results are sorted so
# the checkpoint reported below does not depend on filesystem listing order.
path = sorted(glob("./pt_files/*.pt"))
if not path:
    # Fail with an explicit message instead of a bare IndexError on path[0].
    raise FileNotFoundError("No .pt files found in ./pt_files/")
print(path[0])
# Dataset folder to generate from. Other folders that have been used here:
# 'zelda_original', 'megaman', 'gerudo', 'DK'.
folder = 'zelda'

# Generate the token lists and the total vocabulary for `folder`
# (generateInputFromSongs comes from midi_utils).
tkMsgsList, numUniques, tokenToMsg, msgToToken = generateInputFromSongs(folder)

# ---------------------------------------------------------------------------
# Model hyperparameters -- these MUST match the values used when the weights
# file being loaded was trained.
k = 20
CONTEXT_SIZE = batch_size = k
EMBEDDING_DIM = 10
vocab = numUniques
num_layers = 3
dropout = 0.3
# ---------------------------------------------------------------------------

# Generation settings: number of tokens to produce and the output title.
song_len = 1000
title = f"Prediction_song{folder}"

# Prime generation with the first `batch_size` tokens of tkMsgsList[0]
# (presumably the first song's token sequence -- confirm in midi_utils).
seed_notes = tkMsgsList[0][:batch_size]
seed = torch.tensor(seed_notes)
# Alternative approach (kept for reference): construct one of the bard
# architectures with the hyperparameters above and load a state_dict into it.
#model = RubberBardLSTMFC2(vocab, EMBEDDING_DIM, CONTEXT_SIZE, num_layers, dropout, batch_size=k)
# model = RubberBardFCFCFC(vocab, EMBEDDING_DIM, CONTEXT_SIZE, num_layers, dropout, batch_size)
# model = RubberBardLSTMFC(vocab, EMBEDDING_DIM, CONTEXT_SIZE, num_layers, dropout, batch_size=k)
# model.load_state_dict(torch.load(path[0]), strict=False)?
#model.load_state_dict(torch.load("pt_files/Practice21_Batch20-3Layer0p3.0dropLSTMFC_zelda_Epoch72.pt"), strict=False)
# Load a whole pickled model object from a hard-coded checkpoint path.
# Note: the `path` list found earlier is only printed; it is not used here.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# checkpoint files from a trusted source.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# rejects full-model pickles like this one; pass weights_only=False there.
model = torch.load("pt_files/Practice21_Batch20-3Layer0p3.0dropLSTMFC_zelda_Epoch72Model.pt")
# Put the model in evaluation mode (e.g. disables dropout) for generation.
model.eval()
# Autoregressive generation: run the current window of `seed` tokens through
# the model, take the most likely next token, append it to the song, and
# slide the window forward by one token. This is inference only, so autograd
# is disabled -- no graph is built and the model is never updated.
#
# NOTE(review): 0 is a sentinel for the repeat check below, but it is also
# passed to make_midi as the song's first token -- confirm that is intended.
pred = [0]
with torch.no_grad():
    for _ in range(song_len):
        print(seed)
        out = model(seed)
        # Assumes the model emits exactly one score per vocabulary token;
        # the view below raises otherwise.
        log_prob = F.log_softmax(out.view([1, numUniques]), dim=1)
        new_pred = log_prob[0].argmax().item()
        if new_pred == pred[-1]:
            # Break repetition by substituting a random token. numpy's randint
            # excludes the upper bound, so this draws from [0, vocab). The
            # draw may still coincide with the previous token.
            new_pred = random.randint(0, vocab)
            print("Randomizing...", new_pred)
        else:
            print("Predicting......", new_pred)
        pred.append(new_pred)
        # Drop the oldest token and append the new one (fixed-size window).
        seed = torch.cat((seed[1:], torch.tensor([new_pred])))
make_midi(pred,tokenToMsg, title)