Skip to content

Commit

Permalink
Compute weighted average of individual dev set losses
Browse files Browse the repository at this point in the history
  • Loading branch information
reuben committed Apr 11, 2019
1 parent d264e9b commit 946828c
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions DeepSpeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ def __call__(self, progress, data):

pbar.finish()
mean_loss = total_loss / step_count if step_count > 0 else 0.0
return mean_loss
return mean_loss, step_count

log_info('STARTING Optimization')
train_start_time = datetime.utcnow()
Expand All @@ -516,19 +516,21 @@ def __call__(self, progress, data):
for epoch in range(FLAGS.epochs):
# Training
log_progress('Training epoch %d...' % epoch)
train_loss = run_set('train', epoch, train_init_op)
train_loss, _ = run_set('train', epoch, train_init_op)
log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
checkpoint_saver.save(session, checkpoint_path, global_step=global_step)

if FLAGS.dev_files:
# Validation
dev_loss = 0.0
total_steps = 0
for csv, init_op in zip(dev_csvs, dev_init_ops):
log_progress('Validating epoch %d on %s...' % (epoch, csv))
set_loss = run_set('dev', epoch, init_op, dataset=csv)
dev_loss += set_loss
set_loss, steps = run_set('dev', epoch, init_op, dataset=csv)
dev_loss += set_loss * steps
total_steps += steps
log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, csv, set_loss))
dev_loss = dev_loss / len(dev_csvs)
dev_loss = dev_loss / total_steps

dev_losses.append(dev_loss)

Expand Down

0 comments on commit 946828c

Please sign in to comment.