inference.py
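"""
Run single-utterance speech recognition inference.

Loads a trained model checkpoint, converts one audio file into a
log-mel spectrogram, decodes it with greedy search, and prints the
recognized sentence.

Example invocation (paths are placeholders, not files shipped with
the repository):

    python inference.py \
        --model_path checkpoints/model.pt \
        --audio_path sample.pcm \
        --label_path labels.csv \
        --eos_id 2 \
        --blank_id 1999
"""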
import argparse
import torch
import torch.nn as nn
import librosa
from torch import Tensor
from data.data_loader import load_audio
from vocabulary import label_to_string, load_label
from models.search import GreedySearch
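
# data.data_loader, vocabulary, and models.search are repo-local modules,
# not PyPI packages, so this script is presumably run from the repository root.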


def parse_audio(audio_path: str, audio_extension: str = 'pcm') -> Tensor:
    """Convert an audio file into a log-mel spectrogram tensor."""
    sound = load_audio(audio_path, extension=audio_extension)

    # 80-dim log-mel features at 16 kHz: 20 ms window (n_fft=320),
    # 10 ms hop (hop_length=160). The audio argument is passed by
    # keyword, which librosa >= 0.10 requires.
    melspectrogram = librosa.feature.melspectrogram(
        y=sound,
        sr=16000,
        n_mels=80,
        n_fft=320,
        hop_length=160,
    )
    log_melspectrogram = librosa.amplitude_to_db(melspectrogram)
    log_melspectrogram = torch.FloatTensor(log_melspectrogram)

    return log_melspectrogram
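
# Example (hypothetical file): parse_audio('sample.pcm') returns a FloatTensor
# of shape (n_mels, frames) = (80, T); with a 160-sample hop at 16 kHz there
# is one frame roughly every 10 ms, so T is about num_samples / 160.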


parser = argparse.ArgumentParser(description='inference')
parser.add_argument('--model_path', type=str, default='')
parser.add_argument('--audio_path', type=str, default='')
parser.add_argument('--label_path', type=str, default='')
parser.add_argument('--eos_id', type=int, default=2)
parser.add_argument('--blank_id', type=int, default=1999)
parser.add_argument('--device', type=str, default='cuda')  # 'cuda' or 'cpu'
args = parser.parse_args()

# Use CUDA only when requested and actually available.
use_cuda = args.device == 'cuda' and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

# Featurize the input utterance; the model expects a batch dimension.
feature = parse_audio(args.audio_path)
input_length = torch.LongTensor([feature.size(1)])
target = torch.LongTensor(1, 120)  # uninitialized placeholder target of shape (1, 120)

char2id, id2char = load_label(args.label_path, args.blank_id)
greedy_search = GreedySearch(device)

# Unwrap DataParallel checkpoints so the bare module runs on a single device.
model = torch.load(args.model_path, map_location=device)
if isinstance(model, nn.DataParallel):
    model = model.module
model.eval()

with torch.no_grad():
    y_hats = greedy_search(model, feature.unsqueeze(0).to(device), input_length.to(device), target)
y_hats = y_hats.squeeze(0)

sentence = label_to_string(args.eos_id, args.blank_id, y_hats, id2char)
print(sentence)