From 691f6792f1d78c8bf5f5284bbe9327e570d304d1 Mon Sep 17 00:00:00 2001 From: Haowen Date: Fri, 13 Nov 2020 20:22:43 +0800 Subject: [PATCH] fix some bugs/issues of the training scripts --- egs/librispeech/asr/simple_v1/README.md | 129 +-------- egs/librispeech/asr/simple_v1/model.py | 3 + egs/librispeech/asr/simple_v1/prepare.py | 18 +- egs/librispeech/asr/simple_v1/run.sh | 2 +- egs/librispeech/asr/simple_v1/train.py | 127 ++++---- egs/librispeech/asr/simple_v1/train_fast.py | 302 -------------------- egs/librispeech/asr/simple_v1/wav2letter.py | 108 ------- 7 files changed, 88 insertions(+), 601 deletions(-) delete mode 100755 egs/librispeech/asr/simple_v1/train_fast.py delete mode 100755 egs/librispeech/asr/simple_v1/wav2letter.py diff --git a/egs/librispeech/asr/simple_v1/README.md b/egs/librispeech/asr/simple_v1/README.md index 175368a3..6b4d9d91 100644 --- a/egs/librispeech/asr/simple_v1/README.md +++ b/egs/librispeech/asr/simple_v1/README.md @@ -1,132 +1,9 @@ +### k2_librispeech -## k2_librispeech -An example of how to build G and L FST for K2. - -Most scripts of this example are copied from Kaldi. +An example of how to use k2 and lhotse to train a CTC acoustic model. ### Run scripts -```bash -$ ./run.sh - -$ ls data/lang_nosp - -G.fsa.txt -L.fst.txt -L_disambig.fst.txt -oov.int -oov.txt -phones -phones.txt -words.txt -``` - -### Load L, G into K2 -```python -import k2, _k2 - - -with open('data/lang_nosp/L.fst.txt') as f: - s = f.read() - -Lfst = k2.Fsa.from_openfst(s, acceptor=False) - -with open('data/lang_nosp/G.fsa.txt') as f: - s = f.read() -Gfsa = k2.Fsa.from_openfst(s, acceptor=True) -``` - -### An example of G building -The `toy.arpa` file: -```plain -\data\ -ngram 1=5 -ngram 2=6 -ngram 3=1 - -\1-grams: --2.348754 --99 -1.070027 --4.214113 A -0.5964623 --4.255245 B -0.3214741 --4.20255 C -0.2937318 - -\2-grams: --4.284099 A -0.1969815 --1.100091 A --2.839235 A B -0.1747991 --2.838903 A C -0.5100551 --1.104238 B --1.251644 C - -\3-grams: --0.1605104 A C B - -\end\ -``` - -Build G fst: -```bash -$ local/arpa2fst.py toy.arpa - -0 1 -5.408205947510138 -2 0 -1.070027 -0 3 A -9.703353773992418 -3 0 -0.5964623 -0 4 B -9.79806370403745 -4 0 -0.3214741 -0 5 C -9.676728982562127 -5 0 -0.2937318 -2 6 A -9.864502494310699 -6 3 -0.1969815 -3 1 -2.5330531375369127 -3 7 B -6.53758018650695 -7 4 -0.1747991 -3 8 C -6.536815728256077 -8 5 -0.5100551 -4 1 -2.5426019579175594 -5 1 -2.8820168161354394 -8 9 B -0.36958885431051147 -1 -``` +$ ./run.sh -Draw it by Graphviz: -``` -digraph FST { -rankdir = LR; -size = "8.5,11"; -label = ""; -center = 1; -ranksep = "0.4"; -nodesep = "0.25"; -0 [label = "0", shape = circle, style = bold, fontsize = 14] - 0 -> 1 [label = "/-5.4082", fontsize = 14]; - 0 -> 3 [label = "A/-9.7034", fontsize = 14]; - 0 -> 4 [label = "B/-9.7981", fontsize = 14]; - 0 -> 5 [label = "C/-9.6767", fontsize = 14]; -1 [label = "1", shape = doublecircle, style = solid, fontsize = 14] -2 [label = "2", shape = circle, style = solid, fontsize = 14] - 2 -> 0 [label = "/-1.07", fontsize = 14]; - 2 -> 6 [label = "A/-9.8645", fontsize = 14]; -3 [label = "3", shape = circle, style = solid, fontsize = 14] - 3 -> 0 [label = "/-0.59646", fontsize = 14]; - 3 -> 1 [label = "/-2.5331", fontsize = 14]; - 3 -> 7 [label = "B/-6.5376", fontsize = 14]; - 3 -> 8 [label = "C/-6.5368", fontsize = 14]; -4 [label = "4", shape = circle, style = solid, fontsize = 14] - 4 -> 0 [label = "/-0.32147", fontsize = 14]; - 4 -> 1 [label = "/-2.5426", fontsize = 14]; -5 [label = "5", shape = circle, style = solid, 
fontsize = 14] - 5 -> 0 [label = "/-0.29373", fontsize = 14]; - 5 -> 1 [label = "/-2.882", fontsize = 14]; -6 [label = "6", shape = circle, style = solid, fontsize = 14] - 6 -> 3 [label = "/-0.19698", fontsize = 14]; -7 [label = "7", shape = circle, style = solid, fontsize = 14] - 7 -> 4 [label = "/-0.1748", fontsize = 14]; -8 [label = "8", shape = circle, style = solid, fontsize = 14] - 8 -> 5 [label = "/-0.51006", fontsize = 14]; - 8 -> 9 [label = "B/-0.36959", fontsize = 14]; -9 [label = "9", shape = circle, style = solid, fontsize = 14] -} -``` diff --git a/egs/librispeech/asr/simple_v1/model.py b/egs/librispeech/asr/simple_v1/model.py index 77fc9fca..e2eb5b43 100755 --- a/egs/librispeech/asr/simple_v1/model.py +++ b/egs/librispeech/asr/simple_v1/model.py @@ -1,6 +1,9 @@ from torch import Tensor from torch import nn +# Copyright (c) 2020 Xiaomi Corporation (authors: Daniel Povey, Haowen Qiu) +# Apache 2.0 + class Model(nn.Module): """ diff --git a/egs/librispeech/asr/simple_v1/prepare.py b/egs/librispeech/asr/simple_v1/prepare.py index d0971db6..95758a74 100755 --- a/egs/librispeech/asr/simple_v1/prepare.py +++ b/egs/librispeech/asr/simple_v1/prepare.py @@ -1,5 +1,8 @@ #!/usr/bin/env python3 +# Copyright (c) 2020 Xiaomi Corporation (authors: Junbo Zhang, Haowen Qiu) +# Apache 2.0 + import os from concurrent.futures import ProcessPoolExecutor from pathlib import Path @@ -17,7 +20,7 @@ print("Parts we will prepare: ", dataset_parts) corpus_dir = '/home/storage04/zhuangweiji/data/open-source-data/librispeech/LibriSpeech' -output_dir = 'exp/data1' +output_dir = 'exp/data' librispeech_manifests = prepare_librispeech(corpus_dir, dataset_parts, output_dir) @@ -34,7 +37,6 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -num_jobs = 1 for partition, manifests in librispeech_manifests.items(): print(partition) with LilcomFilesWriter(f'{output_dir}/feats_{partition}' @@ -44,17 +46,7 @@ supervisions=manifests['supervisions']).compute_and_store_features( extractor=Fbank(), storage=storage, - augmenter=augmenter if 'train' in partition else None, + augment_fn=augmenter if 'train' in partition else None, executor=ex) librispeech_manifests[partition]['cuts'] = cut_set cut_set.to_json(output_dir + f'/cuts_{partition}.json.gz') - -cuts_train = SpeechRecognitionDataset( - librispeech_manifests['train-clean-100']['cuts']) -cuts_test = SpeechRecognitionDataset( - librispeech_manifests['test-clean']['cuts']) - -sample = cuts_train[0] -print('Transcript:', sample['text']) -print('Supervisions mask:', sample['supervisions_mask']) -print('Feature matrix:', sample.load_features()) diff --git a/egs/librispeech/asr/simple_v1/run.sh b/egs/librispeech/asr/simple_v1/run.sh index a5c1d20a..0d86de3c 100755 --- a/egs/librispeech/asr/simple_v1/run.sh +++ b/egs/librispeech/asr/simple_v1/run.sh @@ -38,5 +38,5 @@ if [ $stage -le 5 ]; then fi if [ $stage -le 6 ]; then - python3 ./train_fast.py + python3 ./train.py fi diff --git a/egs/librispeech/asr/simple_v1/train.py b/egs/librispeech/asr/simple_v1/train.py index b8f42774..607d8925 100755 --- a/egs/librispeech/asr/simple_v1/train.py +++ b/egs/librispeech/asr/simple_v1/train.py @@ -1,5 +1,8 @@ #!/usr/bin/env python3 +# Copyright (c) 2020 Xiaomi Corporation (authors: Daniel Povey, Haowen Qiu) +# Apache 2.0 + import logging import os import sys @@ -25,24 +28,24 @@ from common import save_checkpoint from common import save_training_info from common import setup_logger -from wav2letter import Wav2Letter +from model import Model + -def create_decoding_graph(texts, 
graph, symbols): - fsas = [] +def create_decoding_graph(texts, L, symbols): + word_ids_list = [] for text in texts: filter_text = [ i if i in symbols._sym2id else '' for i in text.split(' ') ] word_ids = [symbols.get(i) for i in filter_text] - fsa = k2.linear_fsa(word_ids) - fsa = k2.arc_sort(fsa) - decoding_graph = k2.intersect(fsa, graph).invert_() - decoding_graph = k2.add_epsilon_self_loops(decoding_graph) - fsas.append(decoding_graph) - return k2.create_fsa_vec(fsas) + word_ids_list.append(word_ids) + fsa = k2.linear_fsa(word_ids_list) + decoding_graph = k2.intersect(fsa, L).invert_() + decoding_graph = k2.add_epsilon_self_loops(decoding_graph) + return decoding_graph -def get_objf(batch, model, device, graph, symbols, training, optimizer=None): +def get_objf(batch, model, device, L, symbols, training, optimizer=None): feature = batch['features'] supervisions = batch['supervisions'] supervision_segments = torch.stack( @@ -50,12 +53,11 @@ def get_objf(batch, model, device, graph, symbols, training, optimizer=None): supervisions['num_frames']), 1).to(torch.int32) texts = supervisions['text'] assert feature.ndim == 3 - #print(feature.shape) - #print(supervision_segments[:, 1] + supervision_segments[:, 2]) + # print(supervision_segments[:, 1] + supervision_segments[:, 2]) + feature = feature.to(device) # at entry, feature is [N, T, C] feature = feature.permute(0, 2, 1) # now feature is [N, C, T] - feature = feature.to(device) if training: nnet_output = model(feature) else: @@ -65,20 +67,18 @@ def get_objf(batch, model, device, graph, symbols, training, optimizer=None): # nnet_output is [N, C, T] nnet_output = nnet_output.permute(0, 2, 1) # now nnet_output is [N, T, C] - # TODO(haowen): create decoding graph at the beginning of training - decoding_graph = create_decoding_graph(texts, graph, symbols) + # TODO(haowen): create decoding graph (and cache) at the beginning of training + decoding_graph = create_decoding_graph(texts, L, symbols) decoding_graph.to_(device) decoding_graph.scores.requires_grad_(False) - #print(nnet_output.shape) dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments) - #dense_fsa_vec.scores.requires_grad_(True) assert decoding_graph.is_cuda() assert decoding_graph.device == device assert nnet_output.device == device - #print(nnet_output.get_device()) - #print(dense_fsa_vec) - target_graph = k2.intersect_dense_pruned(decoding_graph, dense_fsa_vec, 10, - 10000, 0) + # TODO(haowen): with a small `beam`, we may get empty `target_graph`, + # thus `tot_scores` will be `inf`. Definitely we need to handle this later. + target_graph = k2.intersect_dense_pruned(decoding_graph, dense_fsa_vec, + 10000, 10000, 0) tot_scores = -k2.get_tot_scores(target_graph, True, False).sum() if training: optimizer.zero_grad() @@ -93,35 +93,34 @@ def get_objf(batch, model, device, graph, symbols, training, optimizer=None): return total_objf, total_frames -def get_validation_objf(dataloader, model, device, graph, symbols): +def get_validation_objf(dataloader, model, device, L, symbols): total_objf = 0. total_frames = 0. 
# for display only model.eval() for batch_idx, batch in enumerate(dataloader): - objf, frames = get_objf(batch, model, device, graph, symbols, False) + objf, frames = get_objf(batch, model, device, L, symbols, False) total_objf += objf total_frames += frames return total_objf, total_frames -def train_one_epoch(dataloader, valid_dataloader, model, device, graph, - symbols, optimizer, current_epoch, num_epochs): +def train_one_epoch(dataloader, valid_dataloader, model, device, L, symbols, + optimizer, current_epoch, num_epochs): total_objf = 0. total_frames = 0. model.train() for batch_idx, batch in enumerate(dataloader): - curr_batch_objf, curr_batch_frames = get_objf(batch, model, device, - graph, symbols, True, - optimizer) + curr_batch_objf, curr_batch_frames = get_objf(batch, model, device, L, + symbols, True, optimizer) total_objf += curr_batch_objf total_frames += curr_batch_frames - if batch_idx % 100 == 0: + if batch_idx % 1 == 0: logging.info( 'processing batch {}, current epoch is {}/{} ' 'global average objf: {:.6f} over {} ' @@ -136,12 +135,12 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, graph, curr_batch_frames, )) - if valid_dataloader and batch_idx % 1000 == 0: + if batch_idx > 0 and batch_idx % 1000 == 0: total_valid_objf, total_valid_frames = get_validation_objf( dataloader=valid_dataloader, model=model, device=device, - graph=graph, + L=L, symbols=symbols) model.train() logging.info( @@ -153,46 +152,71 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, graph, def main(): # load L, G, symbol_table lang_dir = 'data/lang_nosp' + symbol_table = k2.SymbolTable.from_file(lang_dir + '/words.txt') + + ## This commented code created LG. We don't need that there. + ## There were problems with disambiguation symbols; the G has + ## disambiguation symbols which L.fst doesn't support. 
+ # if not os.path.exists(lang_dir + '/LG.pt'): + # print("Loading L.fst.txt") + # with open(lang_dir + '/L.fst.txt') as f: + # L = k2.Fsa.from_openfst(f.read(), acceptor=False) + # print("Loading G.fsa.txt") + # with open(lang_dir + '/G.fsa.txt') as f: + # G = k2.Fsa.from_openfst(f.read(), acceptor=True) + # print("Arc-sorting L...") + # L = k2.arc_sort(L.invert_()) + # G = k2.arc_sort(G) + # print(k2.is_arc_sorted(k2.get_properties(L))) + # print(k2.is_arc_sorted(k2.get_properties(G))) + # print("Intersecting L and G") + # graph = k2.intersect(L, G) + # graph = k2.arc_sort(graph) + # print(k2.is_arc_sorted(k2.get_properties(graph))) + # torch.save(graph.as_dict(), lang_dir + '/LG.pt') + # else: + # d = torch.load(lang_dir + '/LG.pt') + # print("Loading pre-prepared LG") + # graph = k2.Fsa.from_dict(d) + + print("Loading L.fst.txt") with open(lang_dir + '/L.fst.txt') as f: L = k2.Fsa.from_openfst(f.read(), acceptor=False) - - with open(lang_dir + '/G.fsa.txt') as f: - G = k2.Fsa.from_openfst(f.read(), acceptor=True) - - with open(lang_dir + '/words.txt') as f: - symbol_table = k2.SymbolTable.from_str(f.read()) - L = k2.arc_sort(L.invert_()) - G = k2.arc_sort(G) - graph = k2.intersect(L, G) - graph = k2.arc_sort(graph) # load dataset - feature_dir = 'exp/data1' + feature_dir = 'exp/data' + print("About to get train cuts") cuts_train = CutSet.from_json(feature_dir + '/cuts_train-clean-100.json.gz') - + #cuts_train = CutSet.from_json(feature_dir + '/cuts_dev-clean.json.gz') + print("About to get dev cuts") cuts_dev = CutSet.from_json(feature_dir + '/cuts_dev-clean.json.gz') + print("About to create train dataset") train = K2SpeechRecognitionIterableDataset(cuts_train, shuffle=True) + print("About to create dev dataset") validate = K2SpeechRecognitionIterableDataset(cuts_dev, shuffle=False) + print("About to create train dataloader") train_dl = torch.utils.data.DataLoader(train, batch_size=None, num_workers=1) + print("About to create dev dataloader") valid_dl = torch.utils.data.DataLoader(validate, batch_size=None, num_workers=1) - dir = 'exp' - setup_logger('{}/log/log-train'.format(dir)) + exp_dir = 'exp' + setup_logger('{}/log/log-train'.format(exp_dir)) if not torch.cuda.is_available(): logging.error('No GPU detected!') sys.exit(-1) + print("About to create model") device_id = 0 device = torch.device('cuda', device_id) - model = Wav2Letter(num_classes=364, input_type='mfcc', num_features=40) + model = Model(num_features=40, num_classes=364) model.to(device) learning_rate = 0.001 @@ -200,8 +224,8 @@ def main(): num_epochs = 10 best_objf = 100000 best_epoch = start_epoch - best_model_path = os.path.join(dir, 'best_model.pt') - best_epoch_info_filename = os.path.join(dir, 'best-epoch-info') + best_model_path = os.path.join(exp_dir, 'best_model.pt') + best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info') optimizer = optim.Adam(model.parameters(), lr=learning_rate, @@ -219,7 +243,7 @@ def main(): valid_dataloader=valid_dl, model=model, device=device, - graph=graph, + L=L, symbols=symbol_table, optimizer=optimizer, current_epoch=epoch, @@ -241,13 +265,14 @@ def main(): best_epoch=best_epoch) # we always save the model for every epoch - model_path = os.path.join(dir, 'epoch-{}.pt'.format(epoch)) + model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch)) save_checkpoint(filename=model_path, model=model, epoch=epoch, learning_rate=curr_learning_rate, objf=objf) - epoch_info_filename = os.path.join(dir, 'epoch-{}-info'.format(epoch)) + epoch_info_filename = 
os.path.join(exp_dir, + 'epoch-{}-info'.format(epoch)) save_training_info(filename=epoch_info_filename, model_path=model_path, current_epoch=epoch, diff --git a/egs/librispeech/asr/simple_v1/train_fast.py b/egs/librispeech/asr/simple_v1/train_fast.py deleted file mode 100755 index 6b15d7c7..00000000 --- a/egs/librispeech/asr/simple_v1/train_fast.py +++ /dev/null @@ -1,302 +0,0 @@ -#!/usr/bin/env python3 - -import logging -import os -import sys -import warnings -from concurrent.futures import ProcessPoolExecutor -from pathlib import Path - -import numpy as np -import torch -import torch.optim as optim -from torch.nn.utils import clip_grad_value_ -import torchaudio -import torchaudio.models -import k2 - -from lhotse import CutSet, Fbank, LilcomFilesWriter, WavAugmenter -from lhotse.dataset import SpeechRecognitionDataset -from lhotse.dataset.speech_recognition import K2DataLoader, K2SpeechRecognitionDataset, \ - K2SpeechRecognitionIterableDataset, concat_cuts -from lhotse.recipes.librispeech import download_and_untar, prepare_librispeech, dataset_parts_full - -from common import load_checkpoint -from common import save_checkpoint -from common import save_training_info -from common import setup_logger -from wav2letter import Wav2Letter - - -def create_decoding_graph(texts, L, symbols): - fsas = [] - for text in texts: - filter_text = [ - i if i in symbols._sym2id else '' for i in text.split(' ') - ] - word_ids = [symbols.get(i) for i in filter_text] - fsa = k2.linear_fsa(word_ids) - print("linear fsa is ", fsa) - fsa = k2.arc_sort(fsa) - print("linear fsa, arc-sorted, is ", fsa) - print("begin") - print(k2.is_arc_sorted(k2.get_properties(fsa))) - decoding_graph = k2.intersect(fsa, L).invert_() - print("linear fsa, composed, is ", fsa) - print("decoding graph is ", decoding_graph) - decoding_graph = k2.add_epsilon_self_loops(decoding_graph) - print("decoding graph with self-loops is ", decoding_graph) - fsas.append(decoding_graph) - return k2.create_fsa_vec(fsas) - - -def get_objf(batch, model, device, L, symbols, training, optimizer=None): - feature = batch['features'] - supervisions = batch['supervisions'] - supervision_segments = torch.stack( - (supervisions['sequence_idx'], supervisions['start_frame'], - supervisions['num_frames']), 1).to(torch.int32) - texts = supervisions['text'] - assert feature.ndim == 3 - #print(feature.shape) - #print(supervision_segments[:, 1] + supervision_segments[:, 2]) - - # at entry, feature is [N, T, C] - feature = feature.permute(0, 2, 1) # now feature is [N, C, T] - feature = feature.to(device) - if training: - nnet_output = model(feature) - else: - with torch.no_grad(): - nnet_output = model(feature) - - # nnet_output is [N, C, T] - nnet_output = nnet_output.permute(0, 2, 1) # now nnet_output is [N, T, C] - - # TODO(haowen): create decoding graph at the beginning of training - decoding_graph = create_decoding_graph(texts, L, symbols) - decoding_graph.to_(device) - decoding_graph.scores.requires_grad_(False) - #print(nnet_output.shape) - dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments) - #dense_fsa_vec.scores.requires_grad_(True) - assert decoding_graph.is_cuda() - assert decoding_graph.device == device - assert nnet_output.device == device - #print(nnet_output.get_device()) - print(decoding_graph.arcs) - print(dense_fsa_vec.dense_fsa_vec) - target_graph = k2.intersect_dense_pruned(decoding_graph, dense_fsa_vec, 10, - 10000, 0) - tot_scores = -k2.get_tot_scores(target_graph, True, False).sum() - if training: - optimizer.zero_grad() - 
tot_scores.backward() - clip_grad_value_(model.parameters(), 5.0) - optimizer.step() - - objf = tot_scores.detach().cpu() - total_objf = objf.item() - total_frames = nnet_output.shape[0] - - return total_objf, total_frames - - -def get_validation_objf(dataloader, model, device, L, symbols): - total_objf = 0. - total_frames = 0. # for display only - - model.eval() - - for batch_idx, batch in enumerate(dataloader): - objf, frames = get_objf(batch, model, device, L, symbols, False) - total_objf += objf - total_frames += frames - - return total_objf, total_frames - - -def train_one_epoch(dataloader, valid_dataloader, model, device, L, - symbols, optimizer, current_epoch, num_epochs): - total_objf = 0. - total_frames = 0. - - model.train() - for batch_idx, batch in enumerate(dataloader): - curr_batch_objf, curr_batch_frames = get_objf(batch, model, device, - L, symbols, True, - optimizer) - - total_objf += curr_batch_objf - total_frames += curr_batch_frames - - if batch_idx % 100 == 0: - logging.info( - 'processing batch {}, current epoch is {}/{} ' - 'global average objf: {:.6f} over {} ' - 'frames, current batch average objf: {:.6f} over {} frames'. - format( - batch_idx, - current_epoch, - num_epochs, - total_objf / total_frames, - total_frames, - curr_batch_objf / curr_batch_frames, - curr_batch_frames, - )) - - if valid_dataloader is not None and batch_idx % 1000 == 0: - total_valid_objf, total_valid_frames = get_validation_objf( - dataloader=valid_dataloader, - model=model, - device=device, - L=L, - symbols=symbols) - model.train() - logging.info( - 'Validation average objf: {:.6f} over {} frames'.format( - total_valid_objf / total_valid_frames, total_valid_frames)) - return total_objf - - -def main(): - # load L, G, symbol_table - lang_dir = 'data/lang_nosp' - with open(lang_dir + '/words.txt') as f: - symbol_table = k2.SymbolTable.from_str(f.read()) - - - ## This commented code created LG. We don't need that there. - ## There were problems with disambiguation symbols; the G has - ## disambiguation symbols which L.fst doesn't support. 
- # if not os.path.exists(lang_dir + '/LG.pt'): - # print("Loading L.fst.txt") - # with open(lang_dir + '/L.fst.txt') as f: - # L = k2.Fsa.from_openfst(f.read(), acceptor=False) - # print("Loading G.fsa.txt") - # with open(lang_dir + '/G.fsa.txt') as f: - # G = k2.Fsa.from_openfst(f.read(), acceptor=True) - # print("Arc-sorting L...") - # L = k2.arc_sort(L.invert_()) - # G = k2.arc_sort(G) - # print(k2.is_arc_sorted(k2.get_properties(L))) - # print(k2.is_arc_sorted(k2.get_properties(G))) - # print("Intersecting L and G") - # graph = k2.intersect(L, G) - # graph = k2.arc_sort(graph) - # print(k2.is_arc_sorted(k2.get_properties(graph))) - # torch.save(graph.as_dict(), lang_dir + '/LG.pt') - # else: - # d = torch.load(lang_dir + '/LG.pt') - # print("Loading pre-prepared LG") - # graph = k2.Fsa.from_dict(d) - - print("Loading L.fst.txt") - with open(lang_dir + '/L.fst.txt') as f: - L = k2.Fsa.from_openfst(f.read(), acceptor=False) - L = k2.arc_sort(L.invert_()) - - # load dataset - feature_dir = 'exp/data1' - print("About to get train cuts") - #cuts_train = CutSet.from_json(feature_dir + - # '/cuts_train-clean-100.json.gz') - cuts_train = CutSet.from_json(feature_dir + '/cuts_dev-clean.json.gz') - print("About to get dev cuts") - cuts_dev = CutSet.from_json(feature_dir + '/cuts_dev-clean.json.gz') - - print("About to create train dataset") - train = K2SpeechRecognitionIterableDataset(cuts_train, - max_frames=1000, - shuffle=True) - print("About to create dev dataset") - validate = K2SpeechRecognitionIterableDataset(cuts_dev, - max_frames=1000, - shuffle=False) - print("About to create train dataloader") - train_dl = torch.utils.data.DataLoader(train, - batch_size=None, - num_workers=1) - print("About to create dev dataloader") - valid_dl = torch.utils.data.DataLoader(validate, - batch_size=None, - num_workers=1) - - exp_dir = 'exp' - setup_logger('{}/log/log-train'.format(exp_dir)) - - if not torch.cuda.is_available(): - logging.error('No GPU detected!') - sys.exit(-1) - print("About to create model") - device_id = 0 - device = torch.device('cuda', device_id) - model = Wav2Letter(num_classes=364, input_type='mfcc', num_features=40) - model.to(device) - - learning_rate = 0.001 - start_epoch = 0 - num_epochs = 10 - best_objf = 100000 - best_epoch = start_epoch - best_model_path = os.path.join(exp_dir, 'best_model.pt') - best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info') - - optimizer = optim.Adam(model.parameters(), - lr=learning_rate, - weight_decay=5e-4) - # optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9) - - for epoch in range(start_epoch, num_epochs): - curr_learning_rate = learning_rate * pow(0.4, epoch) - for param_group in optimizer.param_groups: - param_group['lr'] = curr_learning_rate - - logging.info('epoch {}, learning rate {}'.format( - epoch, curr_learning_rate)) - objf = train_one_epoch(dataloader=train_dl, - valid_dataloader=valid_dl, - model=model, - device=device, - L=L, - symbols=symbol_table, - optimizer=optimizer, - current_epoch=epoch, - num_epochs=num_epochs) - if objf < best_objf: - best_objf = objf - best_epoch = epoch - save_checkpoint(filename=best_model_path, - model=model, - epoch=epoch, - learning_rate=curr_learning_rate, - objf=objf) - save_training_info(filename=best_epoch_info_filename, - model_path=best_model_path, - current_epoch=epoch, - learning_rate=curr_learning_rate, - objf=best_objf, - best_objf=best_objf, - best_epoch=best_epoch) - - # we always save the model for every epoch - model_path = os.path.join(exp_dir, 
'epoch-{}.pt'.format(epoch)) - save_checkpoint(filename=model_path, - model=model, - epoch=epoch, - learning_rate=curr_learning_rate, - objf=objf) - epoch_info_filename = os.path.join(exp_dir, 'epoch-{}-info'.format(epoch)) - save_training_info(filename=epoch_info_filename, - model_path=model_path, - current_epoch=epoch, - learning_rate=curr_learning_rate, - objf=objf, - best_objf=best_objf, - best_epoch=best_epoch) - - logging.warning('Done') - - -if __name__ == '__main__': - main() diff --git a/egs/librispeech/asr/simple_v1/wav2letter.py b/egs/librispeech/asr/simple_v1/wav2letter.py deleted file mode 100755 index 7dd0f27e..00000000 --- a/egs/librispeech/asr/simple_v1/wav2letter.py +++ /dev/null @@ -1,108 +0,0 @@ -from torch import Tensor -from torch import nn - - -class Wav2Letter(nn.Module): - r"""Wav2Letter model architecture from the `Wav2Letter an End-to-End ConvNet-based Speech Recognition System`_. - - .. _Wav2Letter an End-to-End ConvNet-based Speech Recognition System: https://arxiv.org/abs/1609.03193 - - :math:`\text{padding} = \frac{\text{ceil}(\text{kernel} - \text{stride})}{2}` - - Args: - num_classes (int, optional): Number of classes to be classified. (Default: ``40``) - input_type (str, optional): Wav2Letter can use as input: ``waveform``, ``power_spectrum`` - or ``mfcc`` (Default: ``waveform``). - num_features (int, optional): Number of input features that the network will receive (Default: ``1``). - """ - - def __init__(self, - num_classes: int = 40, - input_type: str = "waveform", - num_features: int = 1) -> None: - super(Wav2Letter, self).__init__() - - #nn.BatchNorm1d(num_features=hidden_dim, affine=False) - - acoustic_num_features = 250 if input_type == "waveform" else num_features - acoustic_model = nn.Sequential( - nn.Conv1d(in_channels=acoustic_num_features, - out_channels=250, - kernel_size=3, - stride=1, - padding=1), nn.ReLU(inplace=True), - nn.Conv1d(in_channels=250, - out_channels=250, - kernel_size=3, - stride=1, - padding=1), nn.ReLU(inplace=True), - nn.Conv1d(in_channels=250, - out_channels=250, - kernel_size=3, - stride=1, - padding=1), nn.ReLU(inplace=True), - nn.Conv1d(in_channels=250, - out_channels=250, - kernel_size=3, - stride=1, - padding=1), nn.ReLU(inplace=True), - nn.Conv1d(in_channels=250, - out_channels=250, - kernel_size=3, - stride=1, - padding=1), nn.ReLU(inplace=True), - nn.Conv1d(in_channels=250, - out_channels=250, - kernel_size=3, - stride=1, - padding=1), nn.ReLU(inplace=True), - nn.Conv1d(in_channels=250, - out_channels=250, - kernel_size=3, - stride=1, - padding=1), nn.ReLU(inplace=True), - nn.Conv1d(in_channels=250, - out_channels=250, - kernel_size=3, - stride=1, - padding=1), nn.ReLU(inplace=True), - nn.Conv1d(in_channels=250, - out_channels=2000, - kernel_size=3, - stride=1, - padding=1), nn.ReLU(inplace=True), - nn.Conv1d(in_channels=2000, - out_channels=2000, - kernel_size=1, - stride=1, - padding=0), nn.ReLU(inplace=True), - nn.Conv1d(in_channels=2000, - out_channels=num_classes, - kernel_size=1, - stride=1, - padding=0), nn.ReLU(inplace=True)) - - if input_type == "waveform": - waveform_model = nn.Sequential( - nn.Conv1d(in_channels=num_features, - out_channels=250, - kernel_size=250, - stride=160, - padding=45), nn.ReLU(inplace=True)) - self.acoustic_model = nn.Sequential(waveform_model, acoustic_model) - - if input_type in ["power_spectrum", "mfcc"]: - self.acoustic_model = acoustic_model - - def forward(self, x: Tensor) -> Tensor: - r""" - Args: - x (torch.Tensor): Tensor of dimension (batch_size, num_features, 
input_length).
-
-        Returns:
-            Tensor: Predictor tensor of dimension (batch_size, number_of_classes, input_length).
-        """
-
-        x = self.acoustic_model(x)
-        x = nn.functional.log_softmax(x, dim=1)
-        return x
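
For context on the `create_decoding_graph` change in `train.py` above: the patched code builds one linear FSA per transcript, stacks them into an FsaVec, composes with the arc-sorted, inverted lexicon `L`, and adds epsilon self-loops. Below is a minimal, self-contained sketch of that flow using only the k2 calls that appear in this patch; the `'<UNK>'` fallback token, the example transcript, and the helper name `build_numerator_graphs` are illustrative assumptions, not part of the patch.

```python
import k2


def build_numerator_graphs(texts, L, symbols):
    # Map each transcript to word ids; fall back to an assumed '<UNK>'
    # entry for words missing from words.txt.
    word_ids_list = []
    for text in texts:
        words = [w if w in symbols._sym2id else '<UNK>' for w in text.split(' ')]
        word_ids_list.append([symbols.get(w) for w in words])

    # One linear FSA per utterance, stacked into an FsaVec.
    fsas = k2.linear_fsa(word_ids_list)
    # Intersect on word labels with the lexicon, then invert so the
    # composed graphs carry token labels, and add epsilon self-loops
    # for the dense intersection done later in get_objf().
    graphs = k2.intersect(fsas, L).invert_()
    return k2.add_epsilon_self_loops(graphs)


# Usage, mirroring main() in the patched train.py:
lang_dir = 'data/lang_nosp'
symbols = k2.SymbolTable.from_file(lang_dir + '/words.txt')
with open(lang_dir + '/L.fst.txt') as f:
    L = k2.Fsa.from_openfst(f.read(), acceptor=False)
L = k2.arc_sort(L.invert_())
graphs = build_numerator_graphs(['HELLO WORLD'], L, symbols)
```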
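
The objective computation in `get_objf` then intersects these graphs with the network output. The sketch below shows that step in isolation, again restricted to the k2 calls used in the patch (`DenseFsaVec`, `intersect_dense_pruned` with the widened beam of 10000, and `get_tot_scores`); the function name and the shape notes in the docstring are assumptions for illustration.

```python
import k2


def numerator_objf(nnet_output, supervision_segments, graphs):
    """nnet_output: [N, T, C] log-probs on the same device as `graphs`;
    supervision_segments: int32 [S, 3] rows of
    (sequence_idx, start_frame, num_frames), as in the patch."""
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    # Pruned intersection; the patch uses a large beam (10000) because a
    # small beam can yield empty lattices and infinite total scores.
    lattice = k2.intersect_dense_pruned(graphs, dense_fsa_vec, 10000, 10000, 0)
    # Negative total log-score, summed over utterances; train.py calls
    # backward() on this and takes an Adam step with gradient clipping.
    return -k2.get_tot_scores(lattice, True, False).sum()
```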