Perform separate validation and test epochs per dataset when multiple files are specified (Fixes #1634 and #2043) #2038

Merged (6 commits) on Apr 16, 2019
62 changes: 33 additions & 29 deletions DeepSpeech.py
@@ -14,14 +14,15 @@
import shutil
import tensorflow as tf

from datetime import datetime
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
from evaluate import evaluate
from six.moves import zip, range
from tensorflow.python.tools import freeze_graph
from util.config import Config, initialize_globals
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
from util.flags import create_flags, FLAGS
from util.logging import log_info, log_error, log_debug
from util.logging import log_info, log_error, log_debug, log_progress, create_progressbar


# Graph Creation
@@ -366,7 +367,7 @@ def train():
# Create training and validation datasets
train_set = create_dataset(FLAGS.train_files.split(','),
batch_size=FLAGS.train_batch_size,
cache_path=FLAGS.train_cached_features_path)
cache_path=FLAGS.feature_cache)

iterator = tf.data.Iterator.from_structure(train_set.output_types,
train_set.output_shapes,
@@ -376,10 +377,9 @@
train_init_op = iterator.make_initializer(train_set)

if FLAGS.dev_files:
dev_set = create_dataset(FLAGS.dev_files.split(','),
batch_size=FLAGS.dev_batch_size,
cache_path=FLAGS.dev_cached_features_path)
dev_init_op = iterator.make_initializer(dev_set)
dev_csvs = FLAGS.dev_files.split(',')
dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size) for csv in dev_csvs]
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]

# Dropout
dropout_rates = [tf.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
@@ -445,7 +445,7 @@ def train():
' - consider using load option "auto" or "init".' % FLAGS.load)
sys.exit(1)

def run_set(set_name, init_op):
def run_set(set_name, epoch, init_op, dataset=None):
is_train = set_name == 'train'
train_op = apply_gradient_op if is_train else []
feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
@@ -456,6 +456,7 @@ def run_set(set_name, init_op):
step_summary_writer = step_summary_writers.get(set_name)
checkpoint_time = time.time()

# Setup progress bar
class LossWidget(progressbar.widgets.FormatLabel):
def __init__(self):
progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')
@@ -464,12 +465,12 @@ def __call__(self, progress, data, **kwargs):
data['mean_loss'] = total_loss / step_count if step_count else 0.0
return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)

if FLAGS.show_progressbar:
pbar = progressbar.ProgressBar(widgets=['Epoch {}'.format(epoch),
' | ', progressbar.widgets.Timer(),
' | Steps: ', progressbar.widgets.Counter(),
' | ', LossWidget()])
pbar.start()
prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
widgets = [' | ', progressbar.widgets.Timer(),
' | Steps: ', progressbar.widgets.Counter(),
' | ', LossWidget()]
suffix = ' | Dataset: {}'.format(dataset) if dataset else None
pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()

# Initialize iterator to the appropriate dataset
session.run(init_op)
@@ -486,40 +487,42 @@ def __call__(self, progress, data, **kwargs):
total_loss += batch_loss
step_count += 1

if FLAGS.show_progressbar:
pbar.update(step_count)
pbar.update(step_count)

step_summary_writer.add_summary(step_summary, current_step)

if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
checkpoint_time = time.time()

if FLAGS.show_progressbar:
pbar.finish()

return total_loss / step_count
pbar.finish()
mean_loss = total_loss / step_count if step_count > 0 else 0.0
return mean_loss, step_count

log_info('STARTING Optimization')
train_start_time = datetime.utcnow()
best_dev_loss = float('inf')
dev_losses = []
try:
for epoch in range(FLAGS.epochs):
# Training
if not FLAGS.show_progressbar:
log_info('Training epoch %d...' % epoch)
train_loss = run_set('train', train_init_op)
if not FLAGS.show_progressbar:
log_info('Finished training epoch %d - loss: %f' % (epoch, train_loss))
log_progress('Training epoch %d...' % epoch)
train_loss, _ = run_set('train', epoch, train_init_op)
log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
checkpoint_saver.save(session, checkpoint_path, global_step=global_step)

if FLAGS.dev_files:
# Validation
if not FLAGS.show_progressbar:
log_info('Validating epoch %d...' % epoch)
dev_loss = run_set('dev', dev_init_op)
if not FLAGS.show_progressbar:
log_info('Finished validating epoch %d - loss: %f' % (epoch, dev_loss))
dev_loss = 0.0
total_steps = 0
for csv, init_op in zip(dev_csvs, dev_init_ops):
log_progress('Validating epoch %d on %s...' % (epoch, csv))
set_loss, steps = run_set('dev', epoch, init_op, dataset=csv)
dev_loss += set_loss * steps
total_steps += steps
log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, csv, set_loss))
dev_loss = dev_loss / total_steps

dev_losses.append(dev_loss)

if dev_loss < best_dev_loss:
@@ -543,6 +546,7 @@ def __call__(self, progress, data, **kwargs):
break
except KeyboardInterrupt:
pass
log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
log_debug('Session closed.')


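The validation loop above weights each dataset's mean loss by its step count before averaging, so the combined dev loss used for best-checkpoint selection matches what a single pass over the merged data would report. A minimal standalone sketch of that aggregation; the dataset names and loss values are made up for illustration:

# Sketch of the weighted dev-loss aggregation performed in the training loop.
# Each entry maps a dev CSV to (mean loss over that set, number of steps);
# the numbers are illustrative only.
per_dataset = {
    'dev-clean.csv': (0.42, 120),
    'dev-other.csv': (0.87, 80),
}

total_loss = 0.0
total_steps = 0
for csv, (set_loss, steps) in per_dataset.items():
    total_loss += set_loss * steps
    total_steps += steps

dev_loss = total_loss / total_steps if total_steps > 0 else 0.0
print('combined dev loss: %f' % dev_loss)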
2 changes: 1 addition & 1 deletion bin/run-tc-ldc93s1_new.sh
@@ -14,7 +14,7 @@ fi;

python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--train_files ${ldc93s1_csv} --train_batch_size 1 \
--train_cached_features_path '/tmp/ldc93s1_cache' \
--feature_cache '/tmp/ldc93s1_cache' \
--dev_files ${ldc93s1_csv} --dev_batch_size 1 \
--test_files ${ldc93s1_csv} --test_batch_size 1 \
--n_hidden 100 --epochs $epoch_count \
129 changes: 71 additions & 58 deletions evaluate.py
@@ -18,7 +18,7 @@
from util.evaluate_tools import calculate_report
from util.feeding import create_dataset
from util.flags import create_flags, FLAGS
from util.logging import log_error
from util.logging import log_error, log_progress, create_progressbar
from util.text import levenshtein


@@ -45,12 +45,14 @@ def evaluate(test_csvs, create_model, try_loading):
FLAGS.lm_binary_path, FLAGS.lm_trie_path,
Config.alphabet)

test_set = create_dataset(test_csvs,
batch_size=FLAGS.test_batch_size,
cache_path=FLAGS.test_cached_features_path)
it = test_set.make_one_shot_iterator()
test_csvs = FLAGS.test_files.split(',')
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size) for csv in test_csvs]
iterator = tf.data.Iterator.from_structure(test_sets[0].output_types,
test_sets[0].output_shapes,
output_classes=test_sets[0].output_classes)
test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]

(batch_x, batch_x_len), batch_y = it.get_next()
(batch_x, batch_x_len), batch_y = iterator.get_next()

# One rate per layer
no_dropout = [None] * 6
@@ -67,10 +69,16 @@

tf.train.get_or_create_global_step()

with tf.Session(config=Config.session_config) as session:
# Create a saver using variables from the above newly created graph
saver = tf.train.Saver()
# Get number of accessible CPU cores for this process
try:
num_processes = cpu_count()
except NotImplementedError:
num_processes = 1

# Create a saver using variables from the above newly created graph
saver = tf.train.Saver()

with tf.Session(config=Config.session_config) as session:
# Restore variables from training checkpoint
loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
if not loaded:
@@ -79,70 +87,75 @@
log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'.format(FLAGS.checkpoint_dir))
exit(1)

logitses = []
losses = []
seq_lengths = []
ground_truths = []
def run_test(init_op, dataset):
logitses = []
losses = []
seq_lengths = []
ground_truths = []

print('Computing acoustic model predictions...')
bar = progressbar.ProgressBar(widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()])
bar = create_progressbar(prefix='Computing acoustic model predictions | ',
widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
log_progress('Computing acoustic model predictions...')

step_count = 0
step_count = 0

# First pass, compute losses and transposed logits for decoding
while True:
try:
logits, loss_, lengths, transcripts = session.run([transposed, loss, batch_x_len, batch_y])
except tf.errors.OutOfRangeError:
break
# Initialize iterator to the appropriate dataset
session.run(init_op)

step_count += 1
bar.update(step_count)
# First pass, compute losses and transposed logits for decoding
while True:
try:
logits, loss_, lengths, transcripts = session.run([transposed, loss, batch_x_len, batch_y])
except tf.errors.OutOfRangeError:
break

logitses.append(logits)
losses.extend(loss_)
seq_lengths.append(lengths)
ground_truths.extend(sparse_tensor_value_to_texts(transcripts, Config.alphabet))
step_count += 1
bar.update(step_count)

bar.finish()
logitses.append(logits)
losses.extend(loss_)
seq_lengths.append(lengths)
ground_truths.extend(sparse_tensor_value_to_texts(transcripts, Config.alphabet))

predictions = []
bar.finish()

# Get number of accessible CPU cores for this process
try:
num_processes = cpu_count()
except NotImplementedError:
num_processes = 1
predictions = []

bar = create_progressbar(max_value=step_count,
prefix='Decoding predictions | ').start()
log_progress('Decoding predictions...')

print('Decoding predictions...')
bar = progressbar.ProgressBar(max_value=step_count,
widget=progressbar.AdaptiveETA)
# Second pass, decode logits and compute WER and edit distance metrics
for logits, seq_length in bar(zip(logitses, seq_lengths)):
decoded = ctc_beam_search_decoder_batch(logits, seq_length, Config.alphabet, FLAGS.beam_width,
num_processes=num_processes, scorer=scorer)
predictions.extend(d[0][1] for d in decoded)

# Second pass, decode logits and compute WER and edit distance metrics
for logits, seq_length in bar(zip(logitses, seq_lengths)):
decoded = ctc_beam_search_decoder_batch(logits, seq_length, Config.alphabet, FLAGS.beam_width,
num_processes=num_processes, scorer=scorer)
predictions.extend(d[0][1] for d in decoded)
distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]

distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]
wer, cer, samples = calculate_report(ground_truths, predictions, distances, losses)
mean_loss = np.mean(losses)

wer, cer, samples = calculate_report(ground_truths, predictions, distances, losses)
mean_loss = np.mean(losses)
# Take only the first report_count items
report_samples = itertools.islice(samples, FLAGS.report_count)

# Take only the first report_count items
report_samples = itertools.islice(samples, FLAGS.report_count)
print('Test on %s - WER: %f, CER: %f, loss: %f' %
(dataset, wer, cer, mean_loss))
print('-' * 80)
for sample in report_samples:
print('WER: %f, CER: %f, loss: %f' %
(sample.wer, sample.distance, sample.loss))
print(' - src: "%s"' % sample.src)
print(' - res: "%s"' % sample.res)
print('-' * 80)

print('Test - WER: %f, CER: %f, loss: %f' %
(wer, cer, mean_loss))
print('-' * 80)
for sample in report_samples:
print('WER: %f, CER: %f, loss: %f' %
(sample.wer, sample.distance, sample.loss))
print(' - src: "%s"' % sample.src)
print(' - res: "%s"' % sample.res)
print('-' * 80)
return samples

return samples
samples = []
for csv, init_op in zip(test_csvs, test_init_ops):
print('Testing model on {}'.format(csv))
samples.extend(run_test(init_op, dataset=csv))
return samples


def main(_):
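As in the training changes above, the test code now builds one dataset per CSV but routes them all through a single reinitializable iterator created from the structure of the first set, re-initializing it before each per-dataset pass. A minimal sketch of that TF 1.x pattern with toy in-memory datasets; the names and values are illustrative:

import tensorflow as tf

# Two toy datasets standing in for the per-CSV test sets.
sets = [tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0]),
        tf.data.Dataset.from_tensor_slices([10.0, 20.0])]

# One reinitializable iterator built from the structure of the first set;
# every dataset gets its own initializer op.
iterator = tf.data.Iterator.from_structure(sets[0].output_types,
                                           sets[0].output_shapes)
init_ops = [iterator.make_initializer(s) for s in sets]
next_element = iterator.get_next()

with tf.Session() as session:
    for name, init_op in zip(['set_a', 'set_b'], init_ops):
        session.run(init_op)  # point the shared iterator at this dataset
        while True:
            try:
                print(name, session.run(next_element))
            except tf.errors.OutOfRangeError:
                break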
2 changes: 1 addition & 1 deletion util/feeding.py
@@ -63,7 +63,7 @@ def to_sparse_tuple(sequence):
return indices, sequence, shape


def create_dataset(csvs, batch_size, cache_path):
def create_dataset(csvs, batch_size, cache_path=''):
df = read_csvs(csvs)
df.sort_values(by='wav_filesize', inplace=True)

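Giving cache_path an empty-string default lets the new per-CSV dev and test sets be built without naming a cache file. This hunk does not show how the path is applied inside create_dataset; one plausible reading, assuming it ends up in tf.data's Dataset.cache (where an empty filename means an in-memory cache), is sketched below with a hypothetical helper:

import tensorflow as tf

def cache_features(dataset, cache_path=''):
    # Hypothetical helper mirroring how create_dataset might apply caching:
    # an empty filename keeps the cache in memory, a non-empty path writes
    # cache files to disk.
    return dataset.cache(cache_path)

ds = tf.data.Dataset.range(5).map(lambda x: x * 2)
ds = cache_features(ds)                              # in-memory cache
# ds = cache_features(ds, '/tmp/ldc93s1_cache')      # on-disk cache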
4 changes: 1 addition & 3 deletions util/flags.py
@@ -16,9 +16,7 @@ def create_flags():
f.DEFINE_string('dev_files', '', 'comma separated list of files specifying the dataset used for validation. Multiple files will get merged. If empty, validation will not be run.')
f.DEFINE_string('test_files', '', 'comma separated list of files specifying the dataset used for testing. Multiple files will get merged. If empty, the model will not be tested.')

f.DEFINE_string('train_cached_features_path', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged')
f.DEFINE_string('dev_cached_features_path', '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged')
f.DEFINE_string('test_cached_features_path', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged')
f.DEFINE_string('feature_cache', '', 'path where cached features extracted from --train_files will be saved. If empty, caching will be done in memory and no files will be written.')

f.DEFINE_integer('feature_win_len', 32, 'feature extraction audio window length in milliseconds')
f.DEFINE_integer('feature_win_step', 20, 'feature extraction window step length in milliseconds')
19 changes: 19 additions & 0 deletions util/logging.py
@@ -1,5 +1,8 @@
from __future__ import print_function

import progressbar
import sys

from util.flags import FLAGS


@@ -28,3 +31,19 @@ def log_warn(message):
def log_error(message):
if FLAGS.log_level <= 3:
prefix_print('E ', message)


def create_progressbar(*args, **kwargs):
# Progress bars in stdout by default
if 'fd' not in kwargs:
kwargs['fd'] = sys.stdout

if FLAGS.show_progressbar:
return progressbar.ProgressBar(*args, **kwargs)

return progressbar.NullBar(*args, **kwargs)


def log_progress(message):
if not FLAGS.show_progressbar:
log_info(message)
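Together, create_progressbar and log_progress keep --show_progressbar checks out of the call sites: with the flag set a real progress bar is written to stdout, otherwise a NullBar swallows the updates and log_progress emits plain log lines instead. A short usage sketch within this codebase; the loop body and step count are illustrative:

from util.logging import create_progressbar, log_progress

def run_epoch(epoch, num_steps=100):
    log_progress('Training epoch %d...' % epoch)      # plain log line when the bar is disabled
    bar = create_progressbar(prefix='Epoch {} | '.format(epoch)).start()
    for step in range(num_steps):
        # ... run one training or validation step here ...
        bar.update(step + 1)                          # produces no output with NullBar
    bar.finish()
    log_progress('Finished training epoch %d' % epoch)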