From a85af3da49a87ada3fa9a12760cf5533783060ea Mon Sep 17 00:00:00 2001
From: Reuben Morais <reuben.morais@gmail.com>
Date: Wed, 10 Apr 2019 16:29:11 -0300
Subject: [PATCH 1/6] Do separate validation epochs if multiple input files are
 specified

---
 DeepSpeech.py   | 62 ++++++++++++++++++++++++++++---------------------
 util/feeding.py |  2 +-
 util/flags.py   |  1 -
 3 files changed, 36 insertions(+), 29 deletions(-)

diff --git a/DeepSpeech.py b/DeepSpeech.py
index 0e91b0f53f..3146d6a2a8 100755
--- a/DeepSpeech.py
+++ b/DeepSpeech.py
@@ -376,10 +376,9 @@ def train():
     train_init_op = iterator.make_initializer(train_set)
 
     if FLAGS.dev_files:
-        dev_set = create_dataset(FLAGS.dev_files.split(','),
-                                 batch_size=FLAGS.dev_batch_size,
-                                 cache_path=FLAGS.dev_cached_features_path)
-        dev_init_op = iterator.make_initializer(dev_set)
+        dev_csvs = FLAGS.dev_files.split(',')
+        dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size) for csv in dev_csvs]
+        dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
 
     # Dropout
     dropout_rates = [tf.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
@@ -425,6 +424,15 @@ def train():
 
     initializer = tf.global_variables_initializer()
 
+    # Disable progress logging if needed
+    if FLAGS.show_progressbar:
+        pbar_class = progressbar.ProgressBar
+        def log_progress(*args, **kwargs):
+            pass
+    else:
+        pbar_class = progressbar.NullBar
+        log_progress = log_info
+
     with tf.Session(config=Config.session_config) as session:
         log_debug('Session opened.')
 
@@ -445,7 +453,7 @@ def train():
                           ' - consider using load option "auto" or "init".' % FLAGS.load)
                 sys.exit(1)
 
-        def run_set(set_name, init_op):
+        def run_set(set_name, epoch, init_op, dataset=None):
             is_train = set_name == 'train'
             train_op = apply_gradient_op if is_train else []
             feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
@@ -456,6 +464,7 @@ def run_set(set_name, init_op):
             step_summary_writer = step_summary_writers.get(set_name)
             checkpoint_time = time.time()
 
+            # Setup progress bar
             class LossWidget(progressbar.widgets.FormatLabel):
                 def __init__(self):
                     progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')
@@ -464,12 +473,12 @@ def __call__(self, progress, data, **kwargs):
                     data['mean_loss'] = total_loss / step_count if step_count else 0.0
                     return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)
 
-            if FLAGS.show_progressbar:
-                pbar = progressbar.ProgressBar(widgets=['Epoch {}'.format(epoch),
-                                                        ' | ', progressbar.widgets.Timer(),
-                                                        ' | Steps: ', progressbar.widgets.Counter(),
-                                                        ' | ', LossWidget()])
-                pbar.start()
+            prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
+            widgets = [' | ', progressbar.widgets.Timer(),
+                       ' | Steps: ', progressbar.widgets.Counter(),
+                       ' | ', LossWidget()]
+            suffix = ' | Dataset: {}'.format(dataset) if dataset else None
+            pbar = pbar_class(prefix=prefix, widgets=widgets, suffix=suffix, fd=sys.stdout).start()
 
             # Initialize iterator to the appropriate dataset
             session.run(init_op)
@@ -486,8 +495,7 @@ def __call__(self, progress, data, **kwargs):
                 total_loss += batch_loss
                 step_count += 1
 
-                if FLAGS.show_progressbar:
-                    pbar.update(step_count)
+                pbar.update(step_count)
 
                 step_summary_writer.add_summary(step_summary, current_step)
 
@@ -495,10 +503,9 @@ def __call__(self, progress, data, **kwargs):
                     checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
                     checkpoint_time = time.time()
 
-            if FLAGS.show_progressbar:
-                pbar.finish()
-
-            return total_loss / step_count
+            pbar.finish()
+            mean_loss = total_loss / step_count if step_count > 0 else 0.0
+            return mean_loss
 
         log_info('STARTING Optimization')
         best_dev_loss = float('inf')
@@ -506,20 +513,21 @@ def __call__(self, progress, data, **kwargs):
         try:
             for epoch in range(FLAGS.epochs):
                 # Training
-                if not FLAGS.show_progressbar:
-                    log_info('Training epoch %d...' % epoch)
-                train_loss = run_set('train', train_init_op)
-                if not FLAGS.show_progressbar:
-                    log_info('Finished training epoch %d - loss: %f' % (epoch, train_loss))
+                log_progress('Training epoch %d...' % epoch)
+                train_loss = run_set('train', epoch, train_init_op)
+                log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
                 checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
 
                 if FLAGS.dev_files:
                     # Validation
-                    if not FLAGS.show_progressbar:
-                        log_info('Validating epoch %d...' % epoch)
-                    dev_loss = run_set('dev', dev_init_op)
-                    if not FLAGS.show_progressbar:
-                        log_info('Finished validating epoch %d - loss: %f' % (epoch, dev_loss))
+                    dev_loss = 0.0
+                    for csv, init_op in zip(dev_csvs, dev_init_ops):
+                        log_progress('Validating epoch %d on %s...' % (epoch, csv))
+                        set_loss = run_set('dev', epoch, init_op, dataset=csv)
+                        dev_loss += set_loss
+                        log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, csv, set_loss))
+                    dev_loss = dev_loss / len(dev_csvs)
+
                     dev_losses.append(dev_loss)
 
                     if dev_loss < best_dev_loss:
diff --git a/util/feeding.py b/util/feeding.py
index 2f01c880c5..e15914ab89 100644
--- a/util/feeding.py
+++ b/util/feeding.py
@@ -63,7 +63,7 @@ def to_sparse_tuple(sequence):
     return indices, sequence, shape
 
 
-def create_dataset(csvs, batch_size, cache_path):
+def create_dataset(csvs, batch_size, cache_path=''):
     df = read_csvs(csvs)
     df.sort_values(by='wav_filesize', inplace=True)
 
diff --git a/util/flags.py b/util/flags.py
index a6b6386f67..f9d3bcd3f8 100644
--- a/util/flags.py
+++ b/util/flags.py
@@ -17,7 +17,6 @@ def create_flags():
     f.DEFINE_string('test_files', '', 'comma separated list of files specifying the dataset used for testing. Multiple files will get merged. If empty, the model will not be tested.')
 
     f.DEFINE_string('train_cached_features_path', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged')
-    f.DEFINE_string('dev_cached_features_path', '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged')
     f.DEFINE_string('test_cached_features_path', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged')
 
     f.DEFINE_integer('feature_win_len', 32, 'feature extraction audio window length in milliseconds')

From 58e9b1a78e0bc5ae64cd7c14d3fa964ce1c3fad5 Mon Sep 17 00:00:00 2001
From: Reuben Morais <reuben.morais@gmail.com>
Date: Wed, 10 Apr 2019 16:52:04 -0300
Subject: [PATCH 2/6] Log total optimization time

---
 DeepSpeech.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/DeepSpeech.py b/DeepSpeech.py
index 3146d6a2a8..7cb894583d 100755
--- a/DeepSpeech.py
+++ b/DeepSpeech.py
@@ -14,6 +14,7 @@
 import shutil
 import tensorflow as tf
 
+from datetime import datetime
 from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
 from evaluate import evaluate
 from six.moves import zip, range
@@ -508,6 +509,7 @@ def __call__(self, progress, data, **kwargs):
             return mean_loss
 
         log_info('STARTING Optimization')
+        train_start_time = datetime.utcnow()
         best_dev_loss = float('inf')
         dev_losses = []
         try:
@@ -551,6 +553,7 @@ def __call__(self, progress, data, **kwargs):
                             break
         except KeyboardInterrupt:
             pass
+        log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
     log_debug('Session closed.')
 
 

From 911a1ce4b17339f87f1db015267acdfe9b790680 Mon Sep 17 00:00:00 2001
From: Reuben Morais <reuben.morais@gmail.com>
Date: Wed, 10 Apr 2019 16:52:18 -0300
Subject: [PATCH 3/6] Do separate test epochs if multiple input files are
 specified

---
 evaluate.py   | 127 ++++++++++++++++++++++++++++----------------------
 util/flags.py |   1 -
 2 files changed, 70 insertions(+), 58 deletions(-)

diff --git a/evaluate.py b/evaluate.py
index 95ef4afc00..203bcfea79 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -45,12 +45,14 @@ def evaluate(test_csvs, create_model, try_loading):
                     FLAGS.lm_binary_path, FLAGS.lm_trie_path,
                     Config.alphabet)
 
-    test_set = create_dataset(test_csvs,
-                              batch_size=FLAGS.test_batch_size,
-                              cache_path=FLAGS.test_cached_features_path)
-    it = test_set.make_one_shot_iterator()
+    test_csvs = test_csvs or FLAGS.test_files.split(',')  # honor the caller's list; fall back to --test_files only when none was given
+    test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size) for csv in test_csvs]
+    iterator = tf.data.Iterator.from_structure(test_sets[0].output_types,
+                                               test_sets[0].output_shapes,
+                                               output_classes=test_sets[0].output_classes)
+    test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]
 
-    (batch_x, batch_x_len), batch_y = it.get_next()
+    (batch_x, batch_x_len), batch_y = iterator.get_next()
 
     # One rate per layer
     no_dropout = [None] * 6
@@ -67,10 +69,16 @@ def evaluate(test_csvs, create_model, try_loading):
 
     tf.train.get_or_create_global_step()
 
-    with tf.Session(config=Config.session_config) as session:
-        # Create a saver using variables from the above newly created graph
-        saver = tf.train.Saver()
+    # Get number of accessible CPU cores for this process
+    try:
+        num_processes = cpu_count()
+    except NotImplementedError:
+        num_processes = 1
+
+    # Create a saver using variables from the above newly created graph
+    saver = tf.train.Saver()
 
+    with tf.Session(config=Config.session_config) as session:
         # Restore variables from training checkpoint
         loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
         if not loaded:
@@ -79,70 +87,75 @@ def evaluate(test_csvs, create_model, try_loading):
             log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'.format(FLAGS.checkpoint_dir))
             exit(1)
 
-        logitses = []
-        losses = []
-        seq_lengths = []
-        ground_truths = []
+        def run_test(init_op, dataset):
+            logitses = []
+            losses = []
+            seq_lengths = []
+            ground_truths = []
 
-        print('Computing acoustic model predictions...')
-        bar = progressbar.ProgressBar(widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()])
+            bar = progressbar.ProgressBar(prefix='Computing acoustic model predictions | ',
+                                          widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()],
+                                          fd=sys.stdout).start()
 
-        step_count = 0
+            step_count = 0
 
-        # First pass, compute losses and transposed logits for decoding
-        while True:
-            try:
-                logits, loss_, lengths, transcripts = session.run([transposed, loss, batch_x_len, batch_y])
-            except tf.errors.OutOfRangeError:
-                break
+            # Initialize iterator to the appropriate dataset
+            session.run(init_op)
 
-            step_count += 1
-            bar.update(step_count)
+            # First pass, compute losses and transposed logits for decoding
+            while True:
+                try:
+                    logits, loss_, lengths, transcripts = session.run([transposed, loss, batch_x_len, batch_y])
+                except tf.errors.OutOfRangeError:
+                    break
 
-            logitses.append(logits)
-            losses.extend(loss_)
-            seq_lengths.append(lengths)
-            ground_truths.extend(sparse_tensor_value_to_texts(transcripts, Config.alphabet))
+                step_count += 1
+                bar.update(step_count)
 
-    bar.finish()
+                logitses.append(logits)
+                losses.extend(loss_)
+                seq_lengths.append(lengths)
+                ground_truths.extend(sparse_tensor_value_to_texts(transcripts, Config.alphabet))
 
-    predictions = []
+            bar.finish()
 
-    # Get number of accessible CPU cores for this process
-    try:
-        num_processes = cpu_count()
-    except NotImplementedError:
-        num_processes = 1
+            predictions = []
+
+            bar = progressbar.ProgressBar(max_value=step_count,
+                                          prefix='Decoding predictions | ',
+                                          fd=sys.stdout).start()
 
-    print('Decoding predictions...')
-    bar = progressbar.ProgressBar(max_value=step_count,
-                                  widget=progressbar.AdaptiveETA)
+            # Second pass, decode logits and compute WER and edit distance metrics
+            for logits, seq_length in bar(zip(logitses, seq_lengths)):
+                decoded = ctc_beam_search_decoder_batch(logits, seq_length, Config.alphabet, FLAGS.beam_width,
+                                                        num_processes=num_processes, scorer=scorer)
+                predictions.extend(d[0][1] for d in decoded)
 
-    # Second pass, decode logits and compute WER and edit distance metrics
-    for logits, seq_length in bar(zip(logitses, seq_lengths)):
-        decoded = ctc_beam_search_decoder_batch(logits, seq_length, Config.alphabet, FLAGS.beam_width,
-                                                num_processes=num_processes, scorer=scorer)
-        predictions.extend(d[0][1] for d in decoded)
+            distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]
 
-    distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]
+            wer, cer, samples = calculate_report(ground_truths, predictions, distances, losses)
+            mean_loss = np.mean(losses)
 
-    wer, cer, samples = calculate_report(ground_truths, predictions, distances, losses)
-    mean_loss = np.mean(losses)
+            # Take only the first report_count items
+            report_samples = itertools.islice(samples, FLAGS.report_count)
 
-    # Take only the first report_count items
-    report_samples = itertools.islice(samples, FLAGS.report_count)
+            print('Test on %s - WER: %f, CER: %f, loss: %f' %
+                  (dataset, wer, cer, mean_loss))
+            print('-' * 80)
+            for sample in report_samples:
+                print('WER: %f, CER: %f, loss: %f' %
+                      (sample.wer, sample.distance, sample.loss))
+                print(' - src: "%s"' % sample.src)
+                print(' - res: "%s"' % sample.res)
+                print('-' * 80)
 
-    print('Test - WER: %f, CER: %f, loss: %f' %
-          (wer, cer, mean_loss))
-    print('-' * 80)
-    for sample in report_samples:
-        print('WER: %f, CER: %f, loss: %f' %
-              (sample.wer, sample.distance, sample.loss))
-        print(' - src: "%s"' % sample.src)
-        print(' - res: "%s"' % sample.res)
-        print('-' * 80)
+            return samples
 
-    return samples
+        samples = []
+        for csv, init_op in zip(test_csvs, test_init_ops):
+            print('Testing model on {}'.format(csv))
+            samples.extend(run_test(init_op, dataset=csv))
+        return samples
 
 
 def main(_):
diff --git a/util/flags.py b/util/flags.py
index f9d3bcd3f8..fc0eaafae5 100644
--- a/util/flags.py
+++ b/util/flags.py
@@ -17,7 +17,6 @@ def create_flags():
     f.DEFINE_string('test_files', '', 'comma separated list of files specifying the dataset used for testing. Multiple files will get merged. If empty, the model will not be tested.')
 
     f.DEFINE_string('train_cached_features_path', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged')
-    f.DEFINE_string('test_cached_features_path', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged')
 
     f.DEFINE_integer('feature_win_len', 32, 'feature extraction audio window length in milliseconds')
     f.DEFINE_integer('feature_win_step', 20, 'feature extraction window step length in milliseconds')

From bfa070e6c305a1260fbc159ee1f8f3a8bf257f32 Mon Sep 17 00:00:00 2001
From: Reuben Morais <reuben.morais@gmail.com>
Date: Thu, 11 Apr 2019 11:07:52 -0300
Subject: [PATCH 4/6] Compute weighted average of individual dev set losses

---
 DeepSpeech.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/DeepSpeech.py b/DeepSpeech.py
index 7cb894583d..125b155fbd 100755
--- a/DeepSpeech.py
+++ b/DeepSpeech.py
@@ -506,7 +506,7 @@ def __call__(self, progress, data, **kwargs):
 
             pbar.finish()
             mean_loss = total_loss / step_count if step_count > 0 else 0.0
-            return mean_loss
+            return mean_loss, step_count
 
         log_info('STARTING Optimization')
         train_start_time = datetime.utcnow()
@@ -516,19 +516,21 @@ def __call__(self, progress, data, **kwargs):
             for epoch in range(FLAGS.epochs):
                 # Training
                 log_progress('Training epoch %d...' % epoch)
-                train_loss = run_set('train', epoch, train_init_op)
+                train_loss, _ = run_set('train', epoch, train_init_op)
                 log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
                 checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
 
                 if FLAGS.dev_files:
                     # Validation
                     dev_loss = 0.0
+                    total_steps = 0
                     for csv, init_op in zip(dev_csvs, dev_init_ops):
                         log_progress('Validating epoch %d on %s...' % (epoch, csv))
-                        set_loss = run_set('dev', epoch, init_op, dataset=csv)
-                        dev_loss += set_loss
+                        set_loss, steps = run_set('dev', epoch, init_op, dataset=csv)
+                        dev_loss += set_loss * steps
+                        total_steps += steps
                         log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, csv, set_loss))
-                    dev_loss = dev_loss / len(dev_csvs)
+                    dev_loss = dev_loss / total_steps if total_steps > 0 else 0.0  # step-weighted mean; same empty-set guard as run_set's mean_loss
 
                     dev_losses.append(dev_loss)
 

From 9586fbbd305bb217dfed2ccb99822193b52d6f9d Mon Sep 17 00:00:00 2001
From: Reuben Morais <reuben.morais@gmail.com>
Date: Thu, 11 Apr 2019 15:42:27 -0300
Subject: [PATCH 5/6] Rename --train_cached_features_path to --feature_cache

---
 DeepSpeech.py             | 2 +-
 bin/run-tc-ldc93s1_new.sh | 2 +-
 util/flags.py             | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/DeepSpeech.py b/DeepSpeech.py
index 125b155fbd..15f61c67bd 100755
--- a/DeepSpeech.py
+++ b/DeepSpeech.py
@@ -367,7 +367,7 @@ def train():
     # Create training and validation datasets
     train_set = create_dataset(FLAGS.train_files.split(','),
                                batch_size=FLAGS.train_batch_size,
-                               cache_path=FLAGS.train_cached_features_path)
+                               cache_path=FLAGS.feature_cache)
 
     iterator = tf.data.Iterator.from_structure(train_set.output_types,
                                                train_set.output_shapes,
diff --git a/bin/run-tc-ldc93s1_new.sh b/bin/run-tc-ldc93s1_new.sh
index dc6b6cfdf7..73fc2558ea 100755
--- a/bin/run-tc-ldc93s1_new.sh
+++ b/bin/run-tc-ldc93s1_new.sh
@@ -14,7 +14,7 @@ fi;
 
 python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
   --train_files ${ldc93s1_csv} --train_batch_size 1 \
-  --train_cached_features_path '/tmp/ldc93s1_cache' \
+  --feature_cache '/tmp/ldc93s1_cache' \
   --dev_files ${ldc93s1_csv} --dev_batch_size 1 \
   --test_files ${ldc93s1_csv} --test_batch_size 1 \
   --n_hidden 100 --epochs $epoch_count \
diff --git a/util/flags.py b/util/flags.py
index fc0eaafae5..7973efac23 100644
--- a/util/flags.py
+++ b/util/flags.py
@@ -16,7 +16,7 @@ def create_flags():
     f.DEFINE_string('dev_files', '', 'comma separated list of files specifying the dataset used for validation. Multiple files will get merged. If empty, validation will not be run.')
     f.DEFINE_string('test_files', '', 'comma separated list of files specifying the dataset used for testing. Multiple files will get merged. If empty, the model will not be tested.')
 
-    f.DEFINE_string('train_cached_features_path', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged')
+    f.DEFINE_string('feature_cache', '', 'path where cached features extracted from --train_files will be saved. If empty, caching will be done in memory and no files will be written.')
 
     f.DEFINE_integer('feature_win_len', 32, 'feature extraction audio window length in milliseconds')
     f.DEFINE_integer('feature_win_step', 20, 'feature extraction window step length in milliseconds')

From 904ab1e288ce2441dc8ad2e48adf7e0b6e8983ca Mon Sep 17 00:00:00 2001
From: Reuben Morais <reuben.morais@gmail.com>
Date: Thu, 11 Apr 2019 16:26:11 -0300
Subject: [PATCH 6/6] Centralize progress logging and progress bar logic

---
 DeepSpeech.py   | 13 ++-----------
 evaluate.py     | 14 +++++++-------
 util/logging.py | 19 +++++++++++++++++++
 3 files changed, 28 insertions(+), 18 deletions(-)

diff --git a/DeepSpeech.py b/DeepSpeech.py
index 15f61c67bd..37cf222a9e 100755
--- a/DeepSpeech.py
+++ b/DeepSpeech.py
@@ -22,7 +22,7 @@
 from util.config import Config, initialize_globals
 from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
 from util.flags import create_flags, FLAGS
-from util.logging import log_info, log_error, log_debug
+from util.logging import log_info, log_error, log_debug, log_progress, create_progressbar
 
 
 # Graph Creation
@@ -425,15 +425,6 @@ def train():
 
     initializer = tf.global_variables_initializer()
 
-    # Disable progress logging if needed
-    if FLAGS.show_progressbar:
-        pbar_class = progressbar.ProgressBar
-        def log_progress(*args, **kwargs):
-            pass
-    else:
-        pbar_class = progressbar.NullBar
-        log_progress = log_info
-
     with tf.Session(config=Config.session_config) as session:
         log_debug('Session opened.')
 
@@ -479,7 +470,7 @@ def __call__(self, progress, data, **kwargs):
                        ' | Steps: ', progressbar.widgets.Counter(),
                        ' | ', LossWidget()]
             suffix = ' | Dataset: {}'.format(dataset) if dataset else None
-            pbar = pbar_class(prefix=prefix, widgets=widgets, suffix=suffix, fd=sys.stdout).start()
+            pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()
 
             # Initialize iterator to the appropriate dataset
             session.run(init_op)
diff --git a/evaluate.py b/evaluate.py
index 203bcfea79..2dc767f88c 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -18,7 +18,7 @@
 from util.evaluate_tools import calculate_report
 from util.feeding import create_dataset
 from util.flags import create_flags, FLAGS
-from util.logging import log_error
+from util.logging import log_error, log_progress, create_progressbar
 from util.text import levenshtein
 
 
@@ -93,9 +93,9 @@ def run_test(init_op, dataset):
             seq_lengths = []
             ground_truths = []
 
-            bar = progressbar.ProgressBar(prefix='Computing acoustic model predictions | ',
-                                          widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()],
-                                          fd=sys.stdout).start()
+            bar = create_progressbar(prefix='Computing acoustic model predictions | ',
+                                     widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
+            log_progress('Computing acoustic model predictions...')
 
             step_count = 0
 
@@ -121,9 +121,9 @@ def run_test(init_op, dataset):
 
             predictions = []
 
-            bar = progressbar.ProgressBar(max_value=step_count,
-                                          prefix='Decoding predictions | ',
-                                          fd=sys.stdout).start()
+            bar = create_progressbar(max_value=step_count,
+                                     prefix='Decoding predictions | ').start()
+            log_progress('Decoding predictions...')
 
             # Second pass, decode logits and compute WER and edit distance metrics
             for logits, seq_length in bar(zip(logitses, seq_lengths)):
diff --git a/util/logging.py b/util/logging.py
index b6f9ffb99b..c7643a4454 100644
--- a/util/logging.py
+++ b/util/logging.py
@@ -1,5 +1,8 @@
 from __future__ import print_function
 
+import progressbar
+import sys
+
 from util.flags import FLAGS
 
 
@@ -28,3 +31,19 @@ def log_warn(message):
 def log_error(message):
     if FLAGS.log_level <= 3:
         prefix_print('E ', message)
+
+
+def create_progressbar(*args, **kwargs):
+    # Progress bars in stdout by default
+    if 'fd' not in kwargs:
+        kwargs['fd'] = sys.stdout
+
+    if FLAGS.show_progressbar:
+        return progressbar.ProgressBar(*args, **kwargs)
+
+    return progressbar.NullBar(*args, **kwargs)
+
+
+def log_progress(message):
+    if not FLAGS.show_progressbar:
+        log_info(message)