-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinteract.py
47 lines (39 loc) · 1.48 KB
/
interact.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def main(sequence_indices=(8,)):
    """Let the trained model respond to selected sequences and show the result.

    Loads the model stored at ``config.model_path + '_final'``, prompting for
    another path for as long as loading yields a falsy value. The data is
    loaded, the first part returned by ``split_data`` is kept, and for every
    selected sequence its first element (``seq[:1]``) is passed to
    ``respond_to``. Each element of the response is converted to a text token
    and the token list is handed to ``convert_to_midi(...).show()``.

    Args:
        sequence_indices: Indices (into the kept part of the data) of the
            sequences to respond to. Defaults to ``(8,)``, the index that was
            previously hard-coded; e.g. pass ``(8, 10, 13, 14)`` for several.

    Raises:
        IndexError: Propagated from indexing the data when an index in
            ``sequence_indices`` is out of range.

    Side effects:
        Sets ``config.polyphony`` to ``False`` and may overwrite
        ``config.model_path`` with the path typed by the user.
    """
    import config
    from model import load_model, respond_to

    model = load_model(config.model_path + '_final')
    while not model:
        config.model_path = input('valid model: ')
        # NOTE(review): called without a path, so this relies on load_model's
        # default picking up the updated config.model_path -- confirm.
        model = load_model()

    from data import load_data, split_data
    data = load_data()
    data, _ = split_data(data)
    selected = [data[index] for index in sequence_indices]

    # Force one-token-per-timestep decoding for this script.
    config.polyphony = False

    # Imported once here instead of on every loop iteration.
    from data import note_reverse_dict, convert_to_midi

    for seq in selected:
        response = respond_to(model, seq[:1])

        # Detach, move to host memory if the model ran on the GPU, then
        # convert to arrays so the values can be inspected element-wise.
        tensors = [t.detach() for t in response]
        if config.use_gpu:
            tensors = [t.cpu() for t in tensors]
        arrays = [t.numpy() for t in tensors]

        tokens = [
            _timestep_to_token(step[0], config, note_reverse_dict)
            for step in arrays
        ]
        convert_to_midi(tokens).show()


def _index_to_token(index, note_names, min_octave, rest_index):
    """Return the token for output unit ``index``; ``'R'`` for the rest unit."""
    index = int(index)  # argmax yields a numpy integer; normalise it
    if index == rest_index:
        return 'R'
    # Units are laid out 12 per octave, starting at ``min_octave``.
    return note_names[index % 12] + str(index // 12 + min_octave)


def _timestep_to_token(activations, config, note_names):
    """Convert one output vector into a single text token.

    In polyphonic mode every unit above ``config.pick_threshold`` contributes
    a token, joined by commas, and ``'R'`` is returned when none qualifies.
    Otherwise only the strongest unit is used. In both modes the unit at
    ``config.out_size - 1`` is rendered as ``'R'``; previously the
    non-polyphonic path turned that unit into a note name, which was
    inconsistent with the polyphonic path.
    """
    rest_index = config.out_size - 1
    if config.polyphony:
        active = [
            _index_to_token(unit, note_names, config.min_octave, rest_index)
            for unit, activation in enumerate(activations)
            if activation > config.pick_threshold
        ]
        return ','.join(active) if active else 'R'
    return _index_to_token(
        activations.argmax(), note_names, config.min_octave, rest_index
    )
# Script entry point: run the interactive demo only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()