Add version info to exported graphs
reuben committed Apr 2, 2019
1 parent 57e34c2 commit e1d5df6
Showing 7 changed files with 111 additions and 71 deletions.
141 changes: 73 additions & 68 deletions DeepSpeech.py
@@ -665,85 +665,88 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
     )
 
 
+def file_relative_read(fname):
+    return open(os.path.join(os.path.dirname(__file__), fname)).read()
+
+
 def export():
     r'''
     Restores the trained variables into a simpler graph that will be exported for serving.
     '''
     log_info('Exporting the model...')
-    with tf.device('/cpu:0'):
-        from tensorflow.python.framework.ops import Tensor, Operation
+    from tensorflow.python.framework.ops import Tensor, Operation
 
-        tf.reset_default_graph()
-        session = tf.Session(config=Config.session_config)
-        inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
-        input_names = ",".join(tensor.op.name for tensor in inputs.values())
-        output_names_tensors = [ tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
-        output_names_ops = [ tensor.name for tensor in outputs.values() if isinstance(tensor, Operation)]
-        output_names = ",".join(output_names_tensors + output_names_ops)
-        input_shapes = ":".join(",".join(map(str, tensor.shape)) for tensor in inputs.values())
+    inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
+    input_names = ",".join(tensor.op.name for tensor in inputs.values())
+    output_names_tensors = [ tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor) ]
+    output_names_ops = [ tensor.name for tensor in outputs.values() if isinstance(tensor, Operation) ]
+    output_names = ",".join(output_names_tensors + output_names_ops)
+    input_shapes = ":".join(",".join(map(str, tensor.shape)) for tensor in inputs.values())
 
-        if not FLAGS.export_tflite:
-            mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
-        else:
-            # Create a saver using variables from the above newly created graph
-            def fixup(name):
-                if name.startswith('rnn/lstm_cell/'):
-                    return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
-                return name
-
-            mapping = {fixup(v.op.name): v for v in tf.global_variables()}
+    if not FLAGS.export_tflite:
+        mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
+    else:
+        # Create a saver using variables from the above newly created graph
+        def fixup(name):
+            if name.startswith('rnn/lstm_cell/'):
+                return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
+            return name
+        mapping = {fixup(v.op.name): v for v in tf.global_variables()}
 
-        saver = tf.train.Saver(mapping)
+    saver = tf.train.Saver(mapping)
 
-        # Restore variables from training checkpoint
-        checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
-        checkpoint_path = checkpoint.model_checkpoint_path
+    # Restore variables from training checkpoint
+    checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
+    checkpoint_path = checkpoint.model_checkpoint_path
 
-        output_filename = 'output_graph.pb'
-        if FLAGS.remove_export:
-            if os.path.isdir(FLAGS.export_dir):
-                log_info('Removing old export')
-                shutil.rmtree(FLAGS.export_dir)
-        try:
-            output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
+    output_filename = 'output_graph.pb'
+    if FLAGS.remove_export:
+        if os.path.isdir(FLAGS.export_dir):
+            log_info('Removing old export')
+            shutil.rmtree(FLAGS.export_dir)
+    try:
+        output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
 
-            if not os.path.isdir(FLAGS.export_dir):
-                os.makedirs(FLAGS.export_dir)
+        if not os.path.isdir(FLAGS.export_dir):
+            os.makedirs(FLAGS.export_dir)
 
-            def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
-                return freeze_graph.freeze_graph_with_def_protos(
-                    input_graph_def=session.graph_def,
-                    input_saver_def=saver.as_saver_def(),
-                    input_checkpoint=checkpoint_path,
-                    output_node_names=output_node_names,
-                    restore_op_name=None,
-                    filename_tensor_name=None,
-                    output_graph=output_file,
-                    clear_devices=False,
-                    variable_names_blacklist=variables_blacklist,
-                    initializer_nodes='')
+        def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
+            return freeze_graph.freeze_graph_with_def_protos(
+                input_graph_def=tf.get_default_graph().as_graph_def(),
+                input_saver_def=saver.as_saver_def(),
+                input_checkpoint=checkpoint_path,
+                output_node_names=output_node_names,
+                restore_op_name=None,
+                filename_tensor_name=None,
+                output_graph=output_file,
+                clear_devices=False,
+                variable_names_blacklist=variables_blacklist,
+                initializer_nodes='')
 
-            if not FLAGS.export_tflite:
-                do_graph_freeze(output_file=output_graph_path, output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
-            else:
-                frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='')
-                output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
+        if not FLAGS.export_tflite:
+            frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
+            frozen_graph.version = int(file_relative_read('GRAPH_VERSION').strip())
+            with open(output_graph_path, 'wb') as fout:
+                fout.write(frozen_graph.SerializeToString())
+        else:
+            frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='')
+            output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
 
-                converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
-                converter.post_training_quantize = True
-                # AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
-                converter.allow_custom_ops = True
-                tflite_model = converter.convert()
+            converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
+            converter.post_training_quantize = True
+            # AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
+            converter.allow_custom_ops = True
+            tflite_model = converter.convert()
 
-                with open(output_tflite_path, 'wb') as fout:
-                    fout.write(tflite_model)
+            with open(output_tflite_path, 'wb') as fout:
+                fout.write(tflite_model)
 
-                log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path)))
+            log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path)))
 
-            log_info('Models exported at %s' % (FLAGS.export_dir))
-        except RuntimeError as e:
-            log_error(str(e))
+        log_info('Models exported at %s' % (FLAGS.export_dir))
+    except RuntimeError as e:
+        log_error(str(e))
 
 
 def do_single_file_inference(input_file_path):
@@ -795,18 +798,20 @@ def main(_):
     initialize_globals()
 
     if FLAGS.train:
-        with tf.Graph().as_default():
-            tf.set_random_seed(FLAGS.random_seed)
-            train()
+        tf.reset_default_graph()
+        tf.set_random_seed(FLAGS.random_seed)
+        train()
 
     if FLAGS.test:
-        with tf.Graph().as_default():
-            test()
+        tf.reset_default_graph()
+        test()
 
     if FLAGS.export_dir:
+        tf.reset_default_graph()
         export()
 
     if len(FLAGS.one_shot_infer):
+        tf.reset_default_graph()
         do_single_file_inference(FLAGS.one_shot_infer)
 
 if __name__ == '__main__' :
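Taken together with the native client changes below, the export path stamps the model format version (the contents of GRAPH_VERSION) into the exported GraphDef's top-level version field, and the client refuses to load graphs stamped older than the DS_GRAPH_VERSION it was built against. A minimal sketch of reading the stamp back, assuming TensorFlow 1.x and an exported output_graph.pb in the working directory:

    import tensorflow as tf

    # Parse the exported graph and print the version stamp written by export()
    graph_def = tf.GraphDef()
    with open('output_graph.pb', 'rb') as fin:
        graph_def.ParseFromString(fin.read())
    print(graph_def.version)  # the contents of GRAPH_VERSION, e.g. 1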
1 change: 1 addition & 0 deletions GRAPH_VERSION
@@ -0,0 +1 @@
+1
10 changes: 8 additions & 2 deletions README.md
@@ -33,10 +33,11 @@ See the output of `deepspeech -h` for more information on the use of `deepspeech`
 - [Prerequisites](#prerequisites)
 - [Getting the code](#getting-the-code)
 - [Getting the pre-trained model](#getting-the-pre-trained-model)
-- [CUDA dependency](#cuda-dependency)
 - [Using the model](#using-the-model)
+  - [CUDA dependency](#cuda-dependency)
+  - [Model compatibility](#model-compatibility)
   - [Using the Python package](#using-the-python-package)
-  - [Using the command line client](#using-the-command-line-client)
+  - [Using the command-line client](#using-the-command-line-client)
   - [Using the Node.JS package](#using-the-nodejs-package)
   - [Installing bindings from source](#installing-bindings-from-source)
   - [Third party bindings](#third-party-bindings)
@@ -48,6 +49,7 @@ See the output of `deepspeech -h` for more information on the use of `deepspeech`
 - [Checkpointing](#checkpointing)
 - [Exporting a model for inference](#exporting-a-model-for-inference)
 - [Exporting a model for TFLite](#exporting-a-model-for-tflite)
+- [Making a mmap-able model for inference](#making-a-mmap-able-model-for-inference)
 - [Continuing training from a release model](#continuing-training-from-a-release-model)
 - [Contact/Getting Help](#contactgetting-help)

@@ -88,6 +90,10 @@ There are three ways to use DeepSpeech inference:
 
 The GPU capable builds (Python, NodeJS, C++ etc) depend on the same CUDA runtime as upstream TensorFlow. Currently with TensorFlow r1.12 it depends on CUDA 9.0 and CuDNN v7.2.
 
+### Model compatibility
+
+DeepSpeech models are versioned to keep you from trying to use an incompatible graph with a newer client after a breaking change was made to the code. If you get an error saying your model file version is too old for the client, you should either upgrade to a newer model release, re-export your model from the checkpoint using a newer version of the code, or downgrade your client if you need to use the old model and can't re-export it.
+
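Re-exporting from a checkpoint is the middle of those three options; using the flags DeepSpeech.py already reads (FLAGS.checkpoint_dir and FLAGS.export_dir), the invocation looks roughly like this, a sketch with placeholder paths:

    python3 DeepSpeech.py --checkpoint_dir /path/to/checkpoint --export_dir /path/to/export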
 ### Using the Python package
 
 Pre-built binaries which can be used for performing inference with a trained model can be installed with `pip3`. You can then use the `deepspeech` binary to do speech-to-text on an audio file:
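For instance, a hedged sketch of the usual invocation from this era of the project, with the model, alphabet, and audio paths as placeholders:

    pip3 install deepspeech
    deepspeech --model models/output_graph.pbmm --alphabet models/alphabet.txt --audio my_audio.wav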
11 changes: 10 additions & 1 deletion native_client/BUILD
@@ -21,6 +21,14 @@ genrule(
     local = 1,
 )
 
+genrule(
+    name = "ds_graph_version",
+    outs = ["ds_graph_version.h"],
+    cmd = "$(location :ds_graph_version.sh) >$@",
+    tools = [":ds_graph_version.sh"],
+    local = 1,
+)
+
 KENLM_SOURCES = glob(["kenlm/lm/*.cc", "kenlm/util/*.cc", "kenlm/util/double-conversion/*.cc",
                       "kenlm/lm/*.hh", "kenlm/util/*.hh", "kenlm/util/double-conversion/*.h"],
                      exclude = ["kenlm/*/*test.cc", "kenlm/*/*main.cc"])
@@ -62,7 +70,8 @@ tf_cc_shared_object(
     srcs = ["deepspeech.cc",
             "deepspeech.h",
             "alphabet.h",
-            "ds_version.h"] +
+            "ds_version.h",
+            "ds_graph_version.h"] +
            DECODER_SOURCES,
     copts = select({
         # -fvisibility=hidden is not required on Windows, MSVC hides all declarations by default
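The new genrule generates ds_graph_version.h by running ds_graph_version.sh, which is not among the files shown in this commit. Presumably it turns the GRAPH_VERSION file at the repository root into the DS_GRAPH_VERSION constant checked in deepspeech.cc below; the following is a hypothetical sketch, not the actual script:

    #!/bin/sh
    # Hypothetical sketch of ds_graph_version.sh: emit the repository's
    # GRAPH_VERSION file as a C preprocessor constant.
    echo "#define DS_GRAPH_VERSION $(cat "$(dirname "$0")/../GRAPH_VERSION")"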
11 changes: 11 additions & 0 deletions native_client/deepspeech.cc
@@ -13,6 +13,7 @@
 #include "alphabet.h"
 
 #include "native_client/ds_version.h"
+#include "native_client/ds_graph_version.h"
 
 #ifndef USE_TFLITE
 #include "tensorflow/core/public/session.h"
@@ -654,6 +655,16 @@ DS_CreateModel(const char* aModelPath,
     return DS_ERR_FAIL_CREATE_SESS;
   }
 
+  int graph_version = model->graph_def.version();
+  if (graph_version < DS_GRAPH_VERSION) {
+    std::cerr << "Specified model file version (" << graph_version << ") is "
+              << "incompatible with minimum version supported by this client ("
+              << DS_GRAPH_VERSION << "). See "
+              << "https://github.com/mozilla/DeepSpeech/#model-compatibility "
+              << "for more information" << std::endl;
+    return DS_ERR_MODEL_INCOMPATIBLE;
+  }
+
   for (int i = 0; i < model->graph_def.node_size(); ++i) {
     NodeDef node = model->graph_def.node(i);
     if (node.name() == "input_node") {
1 change: 1 addition & 0 deletions native_client/deepspeech.h
@@ -40,6 +40,7 @@ enum DeepSpeech_Error_Codes
   DS_ERR_INVALID_ALPHABET = 0x2000,
   DS_ERR_INVALID_SHAPE = 0x2001,
   DS_ERR_INVALID_LM = 0x2002,
+  DS_ERR_MODEL_INCOMPATIBLE = 0x2003,
 
   // Runtime failures
   DS_ERR_FAIL_INIT_MMAP = 0x3000,
7 changes: 7 additions & 0 deletions tc-tests-utils.sh
@@ -68,6 +68,13 @@ assert_correct_inference()
   phrase=$(strip "$1")
   expected=$(strip "$2")
 
+  case "${phrase}" in
+    *"incompatible with minimum version"*)
+      echo "Prod model too old for client, skipping test."
+      return 0
+    ;;
+  esac
+
   if [ -z "${phrase}" -o -z "${expected}" ]; then
     echo "One or more empty strings:"
     echo "phrase: <${phrase}>"
