From 066bce3029bbcb1bb547e8ca60df48391072f08d Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 13 Nov 2020 13:35:48 +0800
Subject: [PATCH] Use L.fst instead of LG.fst; bug fixes.

---
 egs/librispeech/asr/simple_v1/train_fast.py | 82 +++++++++++----------
 1 file changed, 45 insertions(+), 37 deletions(-)

diff --git a/egs/librispeech/asr/simple_v1/train_fast.py b/egs/librispeech/asr/simple_v1/train_fast.py
index 33cd02dc..6b15d7c7 100755
--- a/egs/librispeech/asr/simple_v1/train_fast.py
+++ b/egs/librispeech/asr/simple_v1/train_fast.py
@@ -28,7 +28,7 @@
 from wav2letter import Wav2Letter
 
 
-def create_decoding_graph(texts, graph, symbols):
+def create_decoding_graph(texts, L, symbols):
     fsas = []
     for text in texts:
         filter_text = [
@@ -36,21 +36,21 @@ def create_decoding_graph(texts, graph, symbols):
         ]
         word_ids = [symbols.get(i) for i in filter_text]
         fsa = k2.linear_fsa(word_ids)
-        #print("linear fsa is ", fsa)
+        print("linear fsa is ", fsa)
         fsa = k2.arc_sort(fsa)
-        #print("linear fsa, arc-sorted, is ", fsa)
+        print("linear fsa, arc-sorted, is ", fsa)
         print("begin")
         print(k2.is_arc_sorted(k2.get_properties(fsa)))
-        decoding_graph = k2.intersect(fsa, graph).invert_()
-        print("end")
-        #print("decoding graph is ", decoding_graph)
+        decoding_graph = k2.intersect(fsa, L).invert_()
+        print("linear fsa, composed, is ", fsa)
+        print("decoding graph is ", decoding_graph)
         decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
-        #print("decoding graph with self-loops is ", decoding_graph)
+        print("decoding graph with self-loops is ", decoding_graph)
         fsas.append(decoding_graph)
     return k2.create_fsa_vec(fsas)
 
 
-def get_objf(batch, model, device, graph, symbols, training, optimizer=None):
+def get_objf(batch, model, device, L, symbols, training, optimizer=None):
     feature = batch['features']
     supervisions = batch['supervisions']
     supervision_segments = torch.stack(
@@ -74,7 +74,7 @@ def get_objf(batch, model, device, graph, symbols, training, optimizer=None):
     nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]
 
     # TODO(haowen): create decoding graph at the beginning of training
-    decoding_graph = create_decoding_graph(texts, graph, symbols)
+    decoding_graph = create_decoding_graph(texts, L, symbols)
     decoding_graph.to_(device)
     decoding_graph.scores.requires_grad_(False)
     #print(nnet_output.shape)
@@ -102,21 +102,21 @@ def get_objf(batch, model, device, graph, symbols, training, optimizer=None):
     return total_objf, total_frames
 
 
-def get_validation_objf(dataloader, model, device, graph, symbols):
+def get_validation_objf(dataloader, model, device, L, symbols):
     total_objf = 0.
     total_frames = 0.  # for display only
 
     model.eval()
 
     for batch_idx, batch in enumerate(dataloader):
-        objf, frames = get_objf(batch, model, device, graph, symbols, False)
+        objf, frames = get_objf(batch, model, device, L, symbols, False)
         total_objf += objf
         total_frames += frames
 
     return total_objf, total_frames
 
 
-def train_one_epoch(dataloader, valid_dataloader, model, device, graph,
+def train_one_epoch(dataloader, valid_dataloader, model, device, L,
                     symbols, optimizer, current_epoch, num_epochs):
     total_objf = 0.
     total_frames = 0.
@@ -124,7 +124,7 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, graph,
     model.train()
     for batch_idx, batch in enumerate(dataloader):
         curr_batch_objf, curr_batch_frames = get_objf(batch, model, device,
-                                                      graph, symbols, True,
+                                                      L, symbols, True,
                                                       optimizer)
 
         total_objf += curr_batch_objf
@@ -145,12 +145,12 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, graph,
                     curr_batch_frames,
                 ))
 
-        if valid_dataloader and batch_idx % 1000 == 0:
+        if valid_dataloader is not None and batch_idx % 1000 == 0:
             total_valid_objf, total_valid_frames = get_validation_objf(
                 dataloader=valid_dataloader,
                 model=model,
                 device=device,
-                graph=graph,
+                L=L,
                 symbols=symbols)
             model.train()
             logging.info(
@@ -165,28 +165,36 @@ def main():
     with open(lang_dir + '/words.txt') as f:
         symbol_table = k2.SymbolTable.from_str(f.read())
 
-    if not os.path.exists(lang_dir + '/LG.pt'):
-        print("Loading L.fst.txt")
-        with open(lang_dir + '/L.fst.txt') as f:
-            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
-        print("Loading G.fsa.txt")
-        with open(lang_dir + '/G.fsa.txt') as f:
-            G = k2.Fsa.from_openfst(f.read(), acceptor=True)
-        print("Arc-sorting L...")
-        L = k2.arc_sort(L.invert_())
-        G = k2.arc_sort(G)
-        print(k2.is_arc_sorted(k2.get_properties(L)))
-        print(k2.is_arc_sorted(k2.get_properties(G)))
-        print("Intersecting L and G")
-        graph = k2.intersect(L, G)
-        graph = k2.arc_sort(graph)
-        print(k2.is_arc_sorted(k2.get_properties(graph)))
-        torch.save(graph.as_dict(), lang_dir + '/LG.pt')
-    else:
-        d = torch.load(lang_dir + '/LG.pt')
-        print("Loading pre-prepared LG")
-        graph = k2.Fsa.from_dict(d)
+    ## This commented code created LG.  We don't need that there.
+    ## There were problems with disambiguation symbols; the G has
+    ## disambiguation symbols which L.fst doesn't support.
+    # if not os.path.exists(lang_dir + '/LG.pt'):
+    #     print("Loading L.fst.txt")
+    #     with open(lang_dir + '/L.fst.txt') as f:
+    #         L = k2.Fsa.from_openfst(f.read(), acceptor=False)
+    #     print("Loading G.fsa.txt")
+    #     with open(lang_dir + '/G.fsa.txt') as f:
+    #         G = k2.Fsa.from_openfst(f.read(), acceptor=True)
+    #     print("Arc-sorting L...")
+    #     L = k2.arc_sort(L.invert_())
+    #     G = k2.arc_sort(G)
+    #     print(k2.is_arc_sorted(k2.get_properties(L)))
+    #     print(k2.is_arc_sorted(k2.get_properties(G)))
+    #     print("Intersecting L and G")
+    #     graph = k2.intersect(L, G)
+    #     graph = k2.arc_sort(graph)
+    #     print(k2.is_arc_sorted(k2.get_properties(graph)))
+    #     torch.save(graph.as_dict(), lang_dir + '/LG.pt')
+    # else:
+    #     d = torch.load(lang_dir + '/LG.pt')
+    #     print("Loading pre-prepared LG")
+    #     graph = k2.Fsa.from_dict(d)
+
+    print("Loading L.fst.txt")
+    with open(lang_dir + '/L.fst.txt') as f:
+        L = k2.Fsa.from_openfst(f.read(), acceptor=False)
+    L = k2.arc_sort(L.invert_())
 
     # load dataset
     feature_dir = 'exp/data1'
@@ -250,7 +258,7 @@ def main():
             valid_dataloader=valid_dl,
             model=model,
             device=device,
-            graph=graph,
+            L=L,
             symbols=symbol_table,
             optimizer=optimizer,
             current_epoch=epoch,
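
Not part of the patch: a minimal sketch of the L-only decoding-graph construction that the patched train_fast.py performs, in case it is useful to try in isolation. It uses only calls that already appear in the diff (k2.Fsa.from_openfst, k2.arc_sort, k2.linear_fsa, k2.intersect, k2.add_epsilon_self_loops, k2.create_fsa_vec, SymbolTable.get); the lang_dir value, the example transcript, and the handling of the transcript filtering are placeholders, since those lines fall outside these hunks.

    # Sketch only: mirrors create_decoding_graph() and the new L-loading code in main().
    import k2

    lang_dir = 'data/lang'  # placeholder; the real value is set earlier in main() (not shown)

    with open(lang_dir + '/words.txt') as f:
        symbol_table = k2.SymbolTable.from_str(f.read())

    with open(lang_dir + '/L.fst.txt') as f:
        L = k2.Fsa.from_openfst(f.read(), acceptor=False)
    # Invert so words label the side we intersect on, then arc-sort (as main() now does).
    L = k2.arc_sort(L.invert_())

    def decoding_graph_for(text):
        # The diff filters the transcript first (filter_text; its body is not shown), so
        # using the words as-is here assumes they are all present in words.txt.
        word_ids = [symbol_table.get(w) for w in text.split()]
        fsa = k2.arc_sort(k2.linear_fsa(word_ids))
        graph = k2.intersect(fsa, L).invert_()
        return k2.add_epsilon_self_loops(graph)

    # One graph per utterance, batched into an FsaVec as in the training code.
    graphs = k2.create_fsa_vec([decoding_graph_for('HELLO WORLD')])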