This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

Use L.fst instead of LG.fst; bug fixes. #6

Merged
merged 1 commit on Nov 13, 2020
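Summary of the change: the training script no longer builds or loads a composed LG graph (the G used here carries disambiguation symbols that L.fst does not support); instead it loads only the lexicon L, inverts and arc-sorts it, and intersects each utterance's linear word FSA with it. Below is a minimal sketch, not part of the PR, assembled from the k2 calls that appear in this diff; the lang_dir path and the example transcript are placeholders.

```python
import k2

lang_dir = 'data/lang_nosp'  # placeholder path, not taken from this PR

# Load the word symbol table and the lexicon L, as main() now does.
with open(lang_dir + '/words.txt') as f:
    symbol_table = k2.SymbolTable.from_str(f.read())
with open(lang_dir + '/L.fst.txt') as f:
    L = k2.Fsa.from_openfst(f.read(), acceptor=False)
L = k2.arc_sort(L.invert_())  # invert_() puts words on the input side

# Build one utterance's decoding graph, as create_decoding_graph() now does.
text = 'HELLO WORLD'  # illustrative transcript
words = [w if w in symbol_table._sym2id else '<UNK>' for w in text.split(' ')]
word_ids = [symbol_table.get(w) for w in words]

fsa = k2.arc_sort(k2.linear_fsa(word_ids))
decoding_graph = k2.intersect(fsa, L).invert_()
decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
```

As the TODO in get_objf notes, these per-utterance graphs are still rebuilt on every batch; moving that work to the start of training is left for later.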
egs/librispeech/asr/simple_v1/train_fast.py (82 changes: 45 additions, 37 deletions)
@@ -28,29 +28,29 @@
 from wav2letter import Wav2Letter


-def create_decoding_graph(texts, graph, symbols):
+def create_decoding_graph(texts, L, symbols):
     fsas = []
     for text in texts:
         filter_text = [
             i if i in symbols._sym2id else '<UNK>' for i in text.split(' ')
         ]
         word_ids = [symbols.get(i) for i in filter_text]
         fsa = k2.linear_fsa(word_ids)
-        #print("linear fsa is ", fsa)
+        print("linear fsa is ", fsa)
         fsa = k2.arc_sort(fsa)
-        #print("linear fsa, arc-sorted, is ", fsa)
+        print("linear fsa, arc-sorted, is ", fsa)
         print("begin")
         print(k2.is_arc_sorted(k2.get_properties(fsa)))
-        decoding_graph = k2.intersect(fsa, graph).invert_()
-        print("end")
-        #print("decoding graph is ", decoding_graph)
+        decoding_graph = k2.intersect(fsa, L).invert_()
+        print("linear fsa, composed, is ", fsa)
+        print("decoding graph is ", decoding_graph)
         decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
-        #print("decoding graph with self-loops is ", decoding_graph)
+        print("decoding graph with self-loops is ", decoding_graph)
         fsas.append(decoding_graph)
     return k2.create_fsa_vec(fsas)


-def get_objf(batch, model, device, graph, symbols, training, optimizer=None):
+def get_objf(batch, model, device, L, symbols, training, optimizer=None):
     feature = batch['features']
     supervisions = batch['supervisions']
     supervision_segments = torch.stack(
@@ -74,7 +74,7 @@ def get_objf(batch, model, device, graph, symbols, training, optimizer=None):
     nnet_output = nnet_output.permute(0, 2, 1) # now nnet_output is [N, T, C]

     # TODO(haowen): create decoding graph at the beginning of training
-    decoding_graph = create_decoding_graph(texts, graph, symbols)
+    decoding_graph = create_decoding_graph(texts, L, symbols)
     decoding_graph.to_(device)
     decoding_graph.scores.requires_grad_(False)
     #print(nnet_output.shape)
@@ -102,29 +102,29 @@ def get_objf(batch, model, device, graph, symbols, training, optimizer=None):
     return total_objf, total_frames


-def get_validation_objf(dataloader, model, device, graph, symbols):
+def get_validation_objf(dataloader, model, device, L, symbols):
     total_objf = 0.
     total_frames = 0. # for display only

     model.eval()

     for batch_idx, batch in enumerate(dataloader):
-        objf, frames = get_objf(batch, model, device, graph, symbols, False)
+        objf, frames = get_objf(batch, model, device, L, symbols, False)
         total_objf += objf
         total_frames += frames

     return total_objf, total_frames


-def train_one_epoch(dataloader, valid_dataloader, model, device, graph,
+def train_one_epoch(dataloader, valid_dataloader, model, device, L,
                     symbols, optimizer, current_epoch, num_epochs):
     total_objf = 0.
     total_frames = 0.

     model.train()
     for batch_idx, batch in enumerate(dataloader):
         curr_batch_objf, curr_batch_frames = get_objf(batch, model, device,
-                                                      graph, symbols, True,
+                                                      L, symbols, True,
                                                       optimizer)

         total_objf += curr_batch_objf
@@ -145,12 +145,12 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, graph,
                     curr_batch_frames,
                 ))

-        if valid_dataloader and batch_idx % 1000 == 0:
+        if valid_dataloader is not None and batch_idx % 1000 == 0:
             total_valid_objf, total_valid_frames = get_validation_objf(
                 dataloader=valid_dataloader,
                 model=model,
                 device=device,
-                graph=graph,
+                L=L,
                 symbols=symbols)
             model.train()
             logging.info(
@@ -165,28 +165,36 @@ def main():
     with open(lang_dir + '/words.txt') as f:
         symbol_table = k2.SymbolTable.from_str(f.read())

-    if not os.path.exists(lang_dir + '/LG.pt'):
-        print("Loading L.fst.txt")
-        with open(lang_dir + '/L.fst.txt') as f:
-            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
-        print("Loading G.fsa.txt")
-        with open(lang_dir + '/G.fsa.txt') as f:
-            G = k2.Fsa.from_openfst(f.read(), acceptor=True)
-        print("Arc-sorting L...")
-        L = k2.arc_sort(L.invert_())
-        G = k2.arc_sort(G)
-        print(k2.is_arc_sorted(k2.get_properties(L)))
-        print(k2.is_arc_sorted(k2.get_properties(G)))
-        print("Intersecting L and G")
-        graph = k2.intersect(L, G)
-        graph = k2.arc_sort(graph)
-        print(k2.is_arc_sorted(k2.get_properties(graph)))
-        torch.save(graph.as_dict(), lang_dir + '/LG.pt')
-    else:
-        d = torch.load(lang_dir + '/LG.pt')
-        print("Loading pre-prepared LG")
-        graph = k2.Fsa.from_dict(d)

+    ## This commented code created LG. We don't need it here.
+    ## There were problems with disambiguation symbols; the G has
+    ## disambiguation symbols which L.fst doesn't support.
+    # if not os.path.exists(lang_dir + '/LG.pt'):
+    #     print("Loading L.fst.txt")
+    #     with open(lang_dir + '/L.fst.txt') as f:
+    #         L = k2.Fsa.from_openfst(f.read(), acceptor=False)
+    #     print("Loading G.fsa.txt")
+    #     with open(lang_dir + '/G.fsa.txt') as f:
+    #         G = k2.Fsa.from_openfst(f.read(), acceptor=True)
+    #     print("Arc-sorting L...")
+    #     L = k2.arc_sort(L.invert_())
+    #     G = k2.arc_sort(G)
+    #     print(k2.is_arc_sorted(k2.get_properties(L)))
+    #     print(k2.is_arc_sorted(k2.get_properties(G)))
+    #     print("Intersecting L and G")
+    #     graph = k2.intersect(L, G)
+    #     graph = k2.arc_sort(graph)
+    #     print(k2.is_arc_sorted(k2.get_properties(graph)))
+    #     torch.save(graph.as_dict(), lang_dir + '/LG.pt')
+    # else:
+    #     d = torch.load(lang_dir + '/LG.pt')
+    #     print("Loading pre-prepared LG")
+    #     graph = k2.Fsa.from_dict(d)
+
+    print("Loading L.fst.txt")
+    with open(lang_dir + '/L.fst.txt') as f:
+        L = k2.Fsa.from_openfst(f.read(), acceptor=False)
+    L = k2.arc_sort(L.invert_())

     # load dataset
     feature_dir = 'exp/data1'
@@ -250,7 +258,7 @@ def main():
                                valid_dataloader=valid_dl,
                                model=model,
                                device=device,
-                               graph=graph,
+                               L=L,
                                symbols=symbol_table,
                                optimizer=optimizer,
                                current_epoch=epoch,
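One of the bug fixes above changes `if valid_dataloader and batch_idx % 1000 == 0:` to `if valid_dataloader is not None and batch_idx % 1000 == 0:`. A small illustration of why the explicit check is safer, assuming a plain PyTorch DataLoader (this snippet is not from the PR): truthiness falls through to `__len__`, so an empty validation loader would silently skip validation.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

empty_dl = DataLoader(TensorDataset(torch.empty(0, 1)))  # a loader with zero batches

print(bool(empty_dl))        # False: bool() falls back to __len__() == 0
print(empty_dl is not None)  # True: the explicit None check still triggers validation
```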