From 863c5544ca1a652afeafc826d19ce55c03894b52 Mon Sep 17 00:00:00 2001 From: Li Li Date: Tue, 11 Jun 2019 13:19:04 -0700 Subject: [PATCH] evaluate_tflite: Fix shared Queue Also dump output to a file Fixed some trivial pylint issues at the same time Signed-off-by: Li Li --- evaluate_tflite.py | 40 +++++++++++++++++++++++++++++----------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/evaluate_tflite.py b/evaluate_tflite.py index 12aaf03d91..8e9c18ceef 100644 --- a/evaluate_tflite.py +++ b/evaluate_tflite.py @@ -6,14 +6,13 @@ import numpy as np import wave import csv -import sys import os from six.moves import zip, range -from multiprocessing import JoinableQueue, Pool, Process, Queue, cpu_count +from multiprocessing import JoinableQueue, Process, cpu_count, Manager from deepspeech import Model -from util.evaluate_tools import process_decode_result, calculate_report +from util.evaluate_tools import calculate_report r''' This module should be self-contained: @@ -41,15 +40,17 @@ def tflite_worker(model, alphabet, lm, trie, queue_in, queue_out, gpu_mask): while True: msg = queue_in.get() - fin = wave.open(msg['filename'], 'rb') + filename = msg['filename'] + wavname = os.path.splitext(os.path.basename(filename))[0] + fin = wave.open(filename, 'rb') fs = fin.getframerate() audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16) - audio_length = fin.getnframes() * (1/16000) fin.close() - + decoded = ds.stt(audio, fs) - - queue_out.put({'prediction': decoded, 'ground_truth': msg['transcript']}) + + queue_out.put({'wav': wavname, 'prediction': decoded, 'ground_truth': msg['transcript']}) + print(queue_out.qsize(), end='\r') # Update the current progress queue_in.task_done() def main(): @@ -66,10 +67,13 @@ def main(): help='Path to the CSV source file') parser.add_argument('--proc', required=False, default=cpu_count(), type=int, help='Number of processes to spawn, defaulting to number of CPUs') + parser.add_argument('--dump', required=False, action='store_true', default=False, + help='Dump the results as text file, with one line for each wav: "wav transcription"') args = parser.parse_args() + manager = Manager() work_todo = JoinableQueue() # this is where we are going to store input data - work_done = Queue() # this where we are gonna push them out + work_done = manager.Queue() # this where we are gonna push them out processes = [] for i in range(args.proc): @@ -79,27 +83,41 @@ def main(): print([x.name for x in processes]) + wavlist = [] ground_truths = [] predictions = [] losses = [] with open(args.csv, 'r') as csvfile: csvreader = csv.DictReader(csvfile) + count = 0 for row in csvreader: + count += 1 work_todo.put({'filename': row['wav_filename'], 'transcript': row['transcript']}) + print('Totally %d wav entries found in csv\n' % count) work_todo.join() + print('\nTotally %d wav file transcripted' % work_done.qsize()) - while (not work_done.empty()): + while not work_done.empty(): msg = work_done.get() losses.append(0.0) ground_truths.append(msg['ground_truth']) predictions.append(msg['prediction']) + wavlist.append(msg['wav']) - wer, cer, samples = calculate_report(ground_truths, predictions, losses) + wer, cer, _ = calculate_report(ground_truths, predictions, losses) mean_loss = np.mean(losses) print('Test - WER: %f, CER: %f, loss: %f' % (wer, cer, mean_loss)) + if args.dump: + with open(args.csv + '.txt', 'w') as ftxt, open(args.csv + '.out', 'w') as fout: + for wav, txt, out in zip(wavlist, ground_truths, predictions): + ftxt.write('%s %s\n' % (wav, txt)) + fout.write('%s %s\n' % (wav, out)) + print('Reference texts dumped to %s.txt' % args.csv) + print('Transcription dumped to %s.out' % args.csv) + if __name__ == '__main__': main()