Commit

Call toco during export
Alexandre Lissy committed Oct 31, 2018
1 parent c67f66f commit 07faba8
Showing 3 changed files with 61 additions and 25 deletions.
79 changes: 55 additions & 24 deletions DeepSpeech.py
@@ -18,10 +18,12 @@
 import traceback
 import inspect
 import progressbar
+import tempfile

 from functools import partial
 from six.moves import zip, range, filter, urllib, BaseHTTPServer
 from tensorflow.python.tools import freeze_graph
+from tensorflow.contrib.lite.python import tflite_convert
 from threading import Thread, Lock
 from util.audio import audiofile_to_input_vector
 from util.feeding import DataSet, ModelFeeder
@@ -1853,7 +1855,10 @@ def export():
         tf.reset_default_graph()
         session = tf.Session(config=session_config)

-        inputs, outputs = create_inference_graph(batch_size=1, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
+        export_batch_size = 1
+        window_size = 2*n_context+1
+
+        inputs, outputs = create_inference_graph(batch_size=export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)

         if not FLAGS.export_tflite:
             mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
@@ -1872,11 +1877,7 @@ def fixup(name):
         checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
         checkpoint_path = checkpoint.model_checkpoint_path

-        if not FLAGS.export_tflite:
-            output_filename = 'output_graph.pb'
-        else:
-            output_filename = 'output_graph.fb'
-
+        output_filename = 'output_graph.pb'
         if FLAGS.remove_export:
             if os.path.isdir(FLAGS.export_dir):
                 log_info('Removing old export')
@@ -1887,31 +1888,61 @@ def fixup(name):
             if not os.path.isdir(FLAGS.export_dir):
                 os.makedirs(FLAGS.export_dir)

+            def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
+                freeze_graph.freeze_graph_with_def_protos(
+                    input_graph_def=session.graph_def,
+                    input_saver_def=saver.as_saver_def(),
+                    input_checkpoint=checkpoint_path,
+                    output_node_names=output_node_names,
+                    restore_op_name=None,
+                    filename_tensor_name=None,
+                    output_graph=output_file,
+                    clear_devices=False,
+                    variable_names_blacklist=variables_blacklist,
+                    initializer_nodes='')
+
             if not FLAGS.export_tflite:
-                output_node_names = 'logits,initialize_state'
-                variables_blacklist = 'previous_state_c,previous_state_h'
+                do_graph_freeze(output_file=output_graph_path, output_node_names='logits,initialize_state', variables_blacklist='previous_state_c,previous_state_h')
             else:
-                output_node_names = 'logits,new_state_c,new_state_h'
-                variables_blacklist = ''
-
-            # Freeze graph
-            freeze_graph.freeze_graph_with_def_protos(
-                input_graph_def=session.graph_def,
-                input_saver_def=saver.as_saver_def(),
-                input_checkpoint=checkpoint_path,
-                output_node_names=output_node_names,
-                restore_op_name=None,
-                filename_tensor_name=None,
-                output_graph=output_graph_path,
-                clear_devices=False,
-                variable_names_blacklist=variables_blacklist,
-                initializer_nodes='')
+                temp_fd, temp_freeze = tempfile.mkstemp(dir=FLAGS.export_dir)
+                os.close(temp_fd)
+                do_graph_freeze(output_file=temp_freeze, output_node_names='logits,new_state_c,new_state_h', variables_blacklist='')
+                output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
+
+                class TFLiteFlags():
+                    def __init__(self):
+                        self.graph_def_file = temp_freeze
+                        self.inference_type = 'FLOAT'
+                        self.input_arrays = 'input_node,previous_state_c,previous_state_h'
+                        self.input_shapes = '{},{},{},{}:1,{}:1,{}'.format(export_batch_size, FLAGS.n_steps, window_size, n_input, FLAGS.n_hidden, FLAGS.n_hidden)
+                        self.output_arrays = 'logits'
+                        self.output_file = output_tflite_path
+                        self.output_format = 'TFLITE'
+
+                        default_empty = [
+                            'inference_input_type',
+                            'mean_values',
+                            'default_ranges_min', 'default_ranges_max',
+                            'drop_control_dependency',
+                            'reorder_across_fake_quant',
+                            'change_concat_input_ranges',
+                            'allow_custom_ops',
+                            'converter_mode',
+                            'post_training_quantize',
+                            'dump_graphviz_dir',
+                            'dump_graphviz_video'
+                        ]
+                        for e in default_empty:
+                            self.__dict__[e] = None
+
+                flags = TFLiteFlags()
+                tflite_convert._convert_model(flags)
+                os.unlink(temp_freeze)
+                log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path)))

             log_info('Models exported at %s' % (FLAGS.export_dir))
         except RuntimeError as e:
             log_error(str(e))


 def do_single_file_inference(input_file_path):
     with tf.Session(config=session_config) as session:
         inputs, outputs = create_inference_graph(batch_size=1, use_new_decoder=True)
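A note on the `input_shapes` string built in the TFLite branch above: toco expects one shape per entry in `input_arrays`, with shapes separated by `:` and dimensions by `,`. The following sketch reproduces the string using DeepSpeech's defaults of the time as assumed values (`n_input=26`, `n_context=9`, `--n_steps=16`, `--n_hidden=2048`); the numbers are illustrative, not taken from this commit:

```python
# Sketch only: how the toco input_shapes string from the diff is assembled.
export_batch_size = 1
n_steps = 16                     # FLAGS.n_steps (assumed default)
n_context = 9                    # frames of context per side (assumed default)
n_input = 26                     # MFCC features per frame (assumed default)
n_hidden = 2048                  # FLAGS.n_hidden (assumed default)
window_size = 2 * n_context + 1  # 19: full feature window per time step

# One shape per input array, ':'-separated; dimensions ','-separated:
#   input_node       -> 1,16,19,26
#   previous_state_c -> 1,2048
#   previous_state_h -> 1,2048
input_shapes = '{},{},{},{}:1,{}:1,{}'.format(
    export_batch_size, n_steps, window_size, n_input, n_hidden, n_hidden)

assert input_shapes == '1,16,19,26:1,2048:1,2048'
```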
5 changes: 5 additions & 0 deletions README.md
@@ -45,6 +45,7 @@ See the output of `deepspeech -h` for more information on the use of `deepspeech
 - [Training a model](#training-a-model)
 - [Checkpointing](#checkpointing)
 - [Exporting a model for inference](#exporting-a-model-for-inference)
+- [Exporting a model for TFLite](#exporting-a-model-for-tflite)
 - [Distributed computing across more than one machine](#distributed-training-across-more-than-one-machine)
 - [Continuing training from a release model](#continuing-training-from-a-release-model)
 - [Code documentation](#code-documentation)
@@ -317,6 +318,10 @@ Be aware however that checkpoints are only valid for the same model geometry the
 If the `--export_dir` parameter is provided, a model will have been exported to this directory during training.
 Refer to the corresponding [README.md](native_client/README.md) for information on building and running a client that can use the exported model.

+### Exporting a model for TFLite
+
+If you want to experiment with the TF Lite engine, you need to export a model that is compatible with it; to do so, use the `--export_tflite` flag. If you already have a trained model, combine this flag with your checkpoint directory and `--notrain --notest` to run only the export step. This produces an `output_graph.tflite` file in the export directory.
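As an illustration, an export-only run might look like the following; the paths are placeholders, and any other flags depend on how the model was trained:

```bash
python -u DeepSpeech.py \
    --checkpoint_dir /path/to/checkpoint \
    --export_dir /path/to/export \
    --notrain --notest \
    --export_tflite
```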

 ### Making a mmap-able model for inference

 The `output_graph.pb` model file generated in the above step will be loaded in memory to be dealt with when running inference.
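For context, the conversion this section goes on to describe uses TensorFlow's `convert_graphdef_memmapped_format` tool, roughly like so (the output filename is illustrative):

```bash
convert_graphdef_memmapped_format --in_graph=output_graph.pb --out_graph=output_graph.pbmm
```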
2 changes: 1 addition & 1 deletion tc-train-tests.sh
@@ -66,7 +66,7 @@ pushd ${HOME}/DeepSpeech/ds/
 popd

 cp /tmp/train/output_graph.pb ${TASKCLUSTER_ARTIFACTS}
-cp /tmp/train/output_graph.fb ${TASKCLUSTER_ARTIFACTS}
+cp /tmp/train/output_graph.tflite ${TASKCLUSTER_ARTIFACTS}

 if [ ! -z "${CONVERT_GRAPHDEF_MEMMAPPED}" ]; then
   convert_graphdef=$(basename "${CONVERT_GRAPHDEF_MEMMAPPED}")
