-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathgenerate.py
63 lines (51 loc) · 1.83 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import argparse
import numpy
import librosa
import chainer
from WaveGlow import Glow
from utils import Preprocess
import params
# --- Command-line interface ---
# --input and --model are consumed unconditionally further down (by
# Preprocess and chainer.serializers.load_npz), so they are required here:
# omitting one now yields an argparse usage error instead of an obscure
# failure inside library code when None is passed along.
parser = argparse.ArgumentParser()
parser.add_argument('--input', '-i', required=True, help='Input file')
parser.add_argument('--output', '-o', default='Result.wav', help='output file')
parser.add_argument('--model', '-m', required=True,
                    help='Snapshot of trained model')
# Default variance 0.6 ** 2 == 0.36, i.e. a standard deviation of 0.6.
parser.add_argument('--var', '-v', type=float, default=0.6 ** 2,
                    help='Variance of Gaussian distribution')
parser.add_argument('--gpu', '-g', type=int, default=-1,
                    help='GPU ID (negative value indicates CPU)')
args = parser.parse_args()
# --- GPU-only performance tuning ---
# args.gpu is parsed with type=int, so it must be compared against an int;
# comparing it with a list (e.g. `!= [-1]`) is always true and would apply
# these settings even when running on the CPU.
if args.gpu >= 0:
    # Allow cuDNN up to 2 * 512 MiB = 1 GiB of workspace memory.
    chainer.cuda.set_max_workspace_size(2 * 512 * 1024 * 1024)
    # Let cuDNN benchmark and pick the fastest convolution algorithms.
    chainer.global_config.autotune = True
# --- Preprocess the input file ---
# Batch size is fixed at 1: exactly one file is processed, and a leading
# batch axis of length 1 is added below.
# Preprocess(...)(path) returns a pair; only the second element is used, as
# the model's conditioning input. NOTE(review): the discarded first element
# is presumably the waveform — confirm against utils.Preprocess.
_, condition = Preprocess(
    params.sr, params.n_fft, params.hop_length, params.n_mels, params.fmin,
    params.fmax, None)(args.input)
condition = numpy.expand_dims(condition, axis=0)
# --- Build the flow model and restore its trained weights ---
# Arguments are passed positionally in the Glow constructor's order. The
# literal 1 in third position is NOTE(review): presumably the number of
# waveform channels — confirm against WaveGlow.Glow.
glow = Glow(params.hop_length, params.n_mels, 1, params.squeeze_factor,
            params.n_flows, params.n_layers, params.wn_channel,
            params.early_every, params.early_size, params.var)
# The model's parameters are stored under this key prefix inside the
# snapshot file given by --model.
snapshot_prefix = 'updater/model:main/'
chainer.serializers.load_npz(args.model, glow, snapshot_prefix)
# --- Select the compute device ---
# A non-negative --gpu id activates that CUDA device; otherwise run on CPU.
use_gpu = args.gpu >= 0
if use_gpu:
    chainer.cuda.get_device_from_id(args.gpu).use()
# --- Generate audio and write it to disk ---
if use_gpu:
    condition = chainer.cuda.to_gpu(condition, device=args.gpu)
    glow.to_gpu(device=args.gpu)
condition = chainer.Variable(condition)
# Inference only: don't build a computational graph for backprop.
with chainer.using_config('enable_backprop', False):
    output = glow.generate(condition, args.var)
# Move the result to host memory and drop all size-1 axes (presumably
# including the batch axis added during preprocessing).
output = numpy.squeeze(chainer.cuda.to_cpu(output.array))
# librosa.output.write_wav was removed in librosa 0.8. It was a thin wrapper
# around scipy.io.wavfile.write, so use that directly when the librosa
# function is unavailable. Note scipy's argument order: (path, rate, data).
_write_wav = getattr(getattr(librosa, 'output', None), 'write_wav', None)
if _write_wav is not None:
    _write_wav(args.output, output, params.sr)
else:
    import scipy.io.wavfile
    scipy.io.wavfile.write(args.output, params.sr, output)