Skip to content

Commit

Permalink
evaluate_tflite: Fix shared Queue
Browse files Browse the repository at this point in the history
Also dump output to a file

Signed-off-by: Li Li <eggonlea@msn.com>
  • Loading branch information
eggonlea committed Jun 12, 2019
1 parent 94df405 commit 944f39c
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions evaluate_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import os

from six.moves import zip, range
from multiprocessing import JoinableQueue, Pool, Process, Queue, cpu_count
from multiprocessing import JoinableQueue, Pool, Process, Queue, cpu_count, Manager
from deepspeech import Model

from util.evaluate_tools import process_decode_result, calculate_report
Expand Down Expand Up @@ -41,15 +41,18 @@ def tflite_worker(model, alphabet, lm, trie, queue_in, queue_out, gpu_mask):
while True:
msg = queue_in.get()

fin = wave.open(msg['filename'], 'rb')
filename = msg['filename']
wavname = os.path.splitext(os.path.basename(filename))[0]
fin = wave.open(filename, 'rb')
fs = fin.getframerate()
audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
audio_length = fin.getnframes() * (1/16000)
fin.close()

decoded = ds.stt(audio, fs)

queue_out.put({'prediction': decoded, 'ground_truth': msg['transcript']})
queue_out.put({'wav': wavname, 'prediction': decoded, 'ground_truth': msg['transcript']})
print(queue_out.qsize(), end='\r')
queue_in.task_done()

def main():
Expand All @@ -68,8 +71,9 @@ def main():
help='Number of processes to spawn, defaulting to number of CPUs')
args = parser.parse_args()

manager = Manager()
work_todo = JoinableQueue() # this is where we are going to store input data
work_done = Queue() # this where we are gonna push them out
work_done = manager.Queue() # this where we are gonna push them out

processes = []
for i in range(args.proc):
Expand All @@ -79,27 +83,39 @@ def main():

print([x.name for x in processes])

wavlist = []
ground_truths = []
predictions = []
losses = []

with open(args.csv, 'r') as csvfile:
csvreader = csv.DictReader(csvfile)
count = 0
for row in csvreader:
count += 1
work_todo.put({'filename': row['wav_filename'], 'transcript': row['transcript']})
print('Totally %d work todo\n' % count)
work_todo.join()
print('\nTotally %d work done' % work_done.qsize())

while (not work_done.empty()):
msg = work_done.get()
losses.append(0.0)
ground_truths.append(msg['ground_truth'])
predictions.append(msg['prediction'])
wavlist.append(msg['wav'])

wer, cer, samples = calculate_report(ground_truths, predictions, losses)
mean_loss = np.mean(losses)

print('Test - WER: %f, CER: %f, loss: %f' %
(wer, cer, mean_loss))

with open(args.csv + '.txt', 'w') as ftxt:
with open(args.csv + '.out', 'w') as fout:
for wav,txt,out in zip(wavlist, ground_truths, predictions):
ftxt.write('%s %s\n' % (wav, txt))
fout.write('%s %s\n' % (wav, out))

if __name__ == '__main__':
main()

0 comments on commit 944f39c

Please sign in to comment.