diff --git a/DeepSpeech.py b/DeepSpeech.py
index 0da77127de..2cc48d6b48 100755
--- a/DeepSpeech.py
+++ b/DeepSpeech.py
@@ -665,85 +665,88 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
     )
 
 
+def file_relative_read(fname):
+    return open(os.path.join(os.path.dirname(__file__), fname)).read()
+
+
 def export():
     r'''
     Restores the trained variables into a simpler graph that will be exported for serving.
     '''
     log_info('Exporting the model...')
-    with tf.device('/cpu:0'):
-        from tensorflow.python.framework.ops import Tensor, Operation
+    from tensorflow.python.framework.ops import Tensor, Operation
 
-        tf.reset_default_graph()
-        session = tf.Session(config=Config.session_config)
+    inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
+    input_names = ",".join(tensor.op.name for tensor in inputs.values())
+    output_names_tensors = [ tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
+    output_names_ops = [ tensor.name for tensor in outputs.values() if isinstance(tensor, Operation)]
+    output_names = ",".join(output_names_tensors + output_names_ops)
+    input_shapes = ":".join(",".join(map(str, tensor.shape)) for tensor in inputs.values())
 
-        inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
-        input_names = ",".join(tensor.op.name for tensor in inputs.values())
-        output_names_tensors = [ tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor) ]
-        output_names_ops = [ tensor.name for tensor in outputs.values() if isinstance(tensor, Operation) ]
-        output_names = ",".join(output_names_tensors + output_names_ops)
-        input_shapes = ":".join(",".join(map(str, tensor.shape)) for tensor in inputs.values())
+    if not FLAGS.export_tflite:
+        mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
+    else:
+        # Create a saver using variables from the above newly created graph
+        def fixup(name):
+            if name.startswith('rnn/lstm_cell/'):
+                return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
+            return name
 
-        if not FLAGS.export_tflite:
-            mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
-        else:
-            # Create a saver using variables from the above newly created graph
-            def fixup(name):
-                if name.startswith('rnn/lstm_cell/'):
-                    return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
-                return name
+        mapping = {fixup(v.op.name): v for v in tf.global_variables()}
 
-            mapping = {fixup(v.op.name): v for v in tf.global_variables()}
+    saver = tf.train.Saver(mapping)
 
-        saver = tf.train.Saver(mapping)
+    # Restore variables from training checkpoint
+    checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
+    checkpoint_path = checkpoint.model_checkpoint_path
 
-        # Restore variables from training checkpoint
-        checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
-        checkpoint_path = checkpoint.model_checkpoint_path
+    output_filename = 'output_graph.pb'
+    if FLAGS.remove_export:
+        if os.path.isdir(FLAGS.export_dir):
+            log_info('Removing old export')
+            shutil.rmtree(FLAGS.export_dir)
+    try:
+        output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
+
+        if not os.path.isdir(FLAGS.export_dir):
+            os.makedirs(FLAGS.export_dir)
+
+        def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
+            return freeze_graph.freeze_graph_with_def_protos(
+                input_graph_def=tf.get_default_graph().as_graph_def(),
+                input_saver_def=saver.as_saver_def(),
+                input_checkpoint=checkpoint_path,
+                output_node_names=output_node_names,
+                restore_op_name=None,
+                filename_tensor_name=None,
+                output_graph=output_file,
+                clear_devices=False,
+                variable_names_blacklist=variables_blacklist,
+                initializer_nodes='')
 
-        output_filename = 'output_graph.pb'
-        if FLAGS.remove_export:
-            if os.path.isdir(FLAGS.export_dir):
-                log_info('Removing old export')
-                shutil.rmtree(FLAGS.export_dir)
-        try:
-            output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
-
-            if not os.path.isdir(FLAGS.export_dir):
-                os.makedirs(FLAGS.export_dir)
-
-            def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
-                return freeze_graph.freeze_graph_with_def_protos(
-                    input_graph_def=session.graph_def,
-                    input_saver_def=saver.as_saver_def(),
-                    input_checkpoint=checkpoint_path,
-                    output_node_names=output_node_names,
-                    restore_op_name=None,
-                    filename_tensor_name=None,
-                    output_graph=output_file,
-                    clear_devices=False,
-                    variable_names_blacklist=variables_blacklist,
-                    initializer_nodes='')
-
-            if not FLAGS.export_tflite:
-                do_graph_freeze(output_file=output_graph_path, output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
-            else:
-                frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='')
-                output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
+        if not FLAGS.export_tflite:
+            frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
+            frozen_graph.version = int(file_relative_read('GRAPH_VERSION').strip())
+            with open(output_graph_path, 'wb') as fout:
+                fout.write(frozen_graph.SerializeToString())
+        else:
+            frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='')
+            output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
 
-            converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
-            converter.post_training_quantize = True
-            # AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
-            converter.allow_custom_ops = True
-            tflite_model = converter.convert()
+            converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
+            converter.post_training_quantize = True
+            # AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
+            converter.allow_custom_ops = True
+            tflite_model = converter.convert()
 
-            with open(output_tflite_path, 'wb') as fout:
-                fout.write(tflite_model)
+            with open(output_tflite_path, 'wb') as fout:
+                fout.write(tflite_model)
 
-            log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path)))
+            log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path)))
 
-            log_info('Models exported at %s' % (FLAGS.export_dir))
-        except RuntimeError as e:
-            log_error(str(e))
+        log_info('Models exported at %s' % (FLAGS.export_dir))
+    except RuntimeError as e:
+        log_error(str(e))
 
 
 def do_single_file_inference(input_file_path):
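Aside from the restructuring, the functional change in this hunk is the `frozen_graph.version = int(file_relative_read('GRAPH_VERSION').strip())` line: `GraphDef` carries a plain integer `version` field that survives serialization, which is what lets the native client read the stamp back before creating a session. A minimal round-trip sketch of that mechanism (not from the patch; assumes TensorFlow r1.12 as used on this branch):

```python
import tensorflow as tf

# Build a trivial graph and grab its GraphDef, as do_graph_freeze() does
# via tf.get_default_graph().as_graph_def().
tf.reset_default_graph()
tf.identity(tf.placeholder(tf.float32, name='input_node'), name='output_node')
graph_def = tf.get_default_graph().as_graph_def()

# Stamp the version, mirroring `frozen_graph.version = int(...)` above.
graph_def.version = 1

# The stamp survives a serialize/parse round trip, so a client reading
# the exported .pb sees the same number.
reloaded = tf.GraphDef()
reloaded.ParseFromString(graph_def.SerializeToString())
assert reloaded.version == 1
```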
@@ -795,18 +798,20 @@ def main(_):
     initialize_globals()
 
     if FLAGS.train:
-        with tf.Graph().as_default():
-            tf.set_random_seed(FLAGS.random_seed)
-            train()
+        tf.reset_default_graph()
+        tf.set_random_seed(FLAGS.random_seed)
+        train()
 
     if FLAGS.test:
-        with tf.Graph().as_default():
-            test()
+        tf.reset_default_graph()
+        test()
 
     if FLAGS.export_dir:
+        tf.reset_default_graph()
         export()
 
     if len(FLAGS.one_shot_infer):
+        tf.reset_default_graph()
         do_single_file_inference(FLAGS.one_shot_infer)
 
 if __name__ == '__main__' :
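The `with tf.Graph().as_default():` blocks are dropped in favor of `tf.reset_default_graph()` because `export()` above now builds its inference graph into the process-wide default graph and freezes `tf.get_default_graph().as_graph_def()` directly (the old code went through `session.graph_def`). A standalone sketch of the pattern (illustrative, not from the patch):

```python
import tensorflow as tf

# Each phase in main() now starts by wiping the default graph...
tf.reset_default_graph()
tf.placeholder(tf.float32, name='input_node')

# ...so everything created afterwards is visible on the default graph and
# can be serialized later without keeping a Session around.
node_names = [n.name for n in tf.get_default_graph().as_graph_def().node]
assert 'input_node' in node_names
```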
diff --git a/GRAPH_VERSION b/GRAPH_VERSION
new file mode 100644
index 0000000000..56a6051ca2
--- /dev/null
+++ b/GRAPH_VERSION
@@ -0,0 +1 @@
+1
\ No newline at end of file
diff --git a/README.md b/README.md
index 28c00b69c0..d3c0f988a6 100644
--- a/README.md
+++ b/README.md
@@ -33,10 +33,11 @@ See the output of `deepspeech -h` for more information on the use of `deepspeech`
 - [Prerequisites](#prerequisites)
 - [Getting the code](#getting-the-code)
 - [Getting the pre-trained model](#getting-the-pre-trained-model)
-- [CUDA dependency](#cuda-dependency)
 - [Using the model](#using-the-model)
+  - [CUDA dependency](#cuda-dependency)
+  - [Model compatibility](#model-compatibility)
   - [Using the Python package](#using-the-python-package)
-  - [Using the command line client](#using-the-command-line-client)
+  - [Using the command-line client](#using-the-command-line-client)
   - [Using the Node.JS package](#using-the-nodejs-package)
   - [Installing bindings from source](#installing-bindings-from-source)
   - [Third party bindings](#third-party-bindings)
@@ -48,6 +49,7 @@ See the output of `deepspeech -h` for more information on the use of `deepspeech`
 - [Checkpointing](#checkpointing)
 - [Exporting a model for inference](#exporting-a-model-for-inference)
 - [Exporting a model for TFLite](#exporting-a-model-for-tflite)
+  - [Making a mmap-able model for inference](#making-a-mmap-able-model-for-inference)
 - [Continuing training from a release model](#continuing-training-from-a-release-model)
 - [Contact/Getting Help](#contactgetting-help)
@@ -88,6 +90,10 @@ There are three ways to use DeepSpeech inference:
 
 The GPU capable builds (Python, NodeJS, C++ etc) depend on the same CUDA runtime as upstream TensorFlow. Currently with TensorFlow r1.12 it depends on CUDA 9.0 and CuDNN v7.2.
 
+### Model compatibility
+
+DeepSpeech models are versioned to keep you from trying to use an incompatible graph with a newer client after a breaking change was made to the code. If you get an error saying your model file version is too old for the client, you should either upgrade to a newer model release, re-export your model from the checkpoint using a newer version of the code, or downgrade your client if you need to use the old model and can't re-export it.
+
 ### Using the Python package
 
 Pre-built binaries which can be used for performing inference with a trained model can be installed with `pip3`. You can then use the `deepspeech` binary to do speech-to-text on an audio file:
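To make the new "Model compatibility" section concrete: after this change, the exported `output_graph.pb` carries the number from `GRAPH_VERSION` in its `GraphDef.version` field, so a model can be checked against a checkout without loading it into a client. A hypothetical helper script (the name `check_graph_version.py` and its arguments are illustrative, not part of this patch):

```python
import os
import sys

import tensorflow as tf

def model_graph_version(pb_path):
    # Parse the frozen model and read back the version stamped by export().
    graph_def = tf.GraphDef()
    with open(pb_path, 'rb') as fin:
        graph_def.ParseFromString(fin.read())
    return graph_def.version

if __name__ == '__main__':
    pb_path, repo_root = sys.argv[1], sys.argv[2]
    minimum = int(open(os.path.join(repo_root, 'GRAPH_VERSION')).read().strip())
    version = model_graph_version(pb_path)
    if version < minimum:
        sys.exit('Model graph version %d is older than minimum %d' % (version, minimum))
    print('Model graph version %d is compatible' % version)
```

Run as e.g. `python check_graph_version.py path/to/output_graph.pb .` from a checkout root (again, hypothetical tooling this patch does not add).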
diff --git a/native_client/BUILD b/native_client/BUILD
index cddf39048f..bf4e1d2654 100644
--- a/native_client/BUILD
+++ b/native_client/BUILD
@@ -21,6 +21,14 @@ genrule(
     local = 1,
 )
 
+genrule(
+    name = "ds_graph_version",
+    outs = ["ds_graph_version.h"],
+    cmd = "$(location :ds_graph_version.sh) >$@",
+    tools = [":ds_graph_version.sh"],
+    local = 1,
+)
+
 KENLM_SOURCES = glob(["kenlm/lm/*.cc", "kenlm/util/*.cc", "kenlm/util/double-conversion/*.cc", "kenlm/lm/*.hh", "kenlm/util/*.hh", "kenlm/util/double-conversion/*.h"], exclude = ["kenlm/*/*test.cc", "kenlm/*/*main.cc"])
@@ -62,7 +70,8 @@ tf_cc_shared_object(
     srcs = ["deepspeech.cc",
             "deepspeech.h",
             "alphabet.h",
-            "ds_version.h"] +
+            "ds_version.h",
+            "ds_graph_version.h"] +
            DECODER_SOURCES,
     copts = select({
         # -fvisibility=hidden is not required on Windows, MSCV hides all declarations by default
diff --git a/native_client/deepspeech.cc b/native_client/deepspeech.cc
index 5fe174f274..b3451433a6 100644
--- a/native_client/deepspeech.cc
+++ b/native_client/deepspeech.cc
@@ -13,6 +13,7 @@
 
 #include "alphabet.h"
 #include "native_client/ds_version.h"
+#include "native_client/ds_graph_version.h"
 
 #ifndef USE_TFLITE
 #include "tensorflow/core/public/session.h"
@@ -654,6 +655,16 @@ DS_CreateModel(const char* aModelPath,
     return DS_ERR_FAIL_CREATE_SESS;
   }
 
+  int graph_version = model->graph_def.version();
+  if (graph_version < DS_GRAPH_VERSION) {
+    std::cerr << "Specified model file version (" << graph_version << ") is "
+              << "incompatible with minimum version supported by this client ("
+              << DS_GRAPH_VERSION << "). See "
+              << "https://github.com/mozilla/DeepSpeech/#model-compatibility "
+              << "for more information" << std::endl;
+    return DS_ERR_MODEL_INCOMPATIBLE;
+  }
+
   for (int i = 0; i < model->graph_def.node_size(); ++i) {
     NodeDef node = model->graph_def.node(i);
     if (node.name() == "input_node") {
diff --git a/native_client/deepspeech.h b/native_client/deepspeech.h
index 3ace31515a..eb8b230fbe 100644
--- a/native_client/deepspeech.h
+++ b/native_client/deepspeech.h
@@ -40,6 +40,7 @@ enum DeepSpeech_Error_Codes
   DS_ERR_INVALID_ALPHABET = 0x2000,
   DS_ERR_INVALID_SHAPE = 0x2001,
   DS_ERR_INVALID_LM = 0x2002,
+  DS_ERR_MODEL_INCOMPATIBLE = 0x2003,
 
   // Runtime failures
   DS_ERR_FAIL_INIT_MMAP = 0x3000,
diff --git a/tc-tests-utils.sh b/tc-tests-utils.sh
index ea698b114a..3206dde8ff 100755
--- a/tc-tests-utils.sh
+++ b/tc-tests-utils.sh
@@ -68,6 +68,13 @@ assert_correct_inference()
   phrase=$(strip "$1")
   expected=$(strip "$2")
 
+  case "${phrase}" in
+    *"incompatible with minimum version"*)
+      echo "Prod model too old for client, skipping test."
+      return 0
+    ;;
+  esac
+
   if [ -z "${phrase}" -o -z "${expected}" ]; then
     echo "One or more empty strings:"
     echo "phrase: <${phrase}>"
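The native-client changes above reduce to a single comparison at model-load time: the stamped graph version must be at least the client's build-time minimum, otherwise `DS_CreateModel` fails with the new error code (and that failure message is exactly what the `tc-tests-utils.sh` hunk watches for to skip prod-model tests). A Python paraphrase of that gate (constants copied from this diff; the function itself is illustrative, not part of the patch):

```python
DS_ERR_MODEL_INCOMPATIBLE = 0x2003  # from deepspeech.h above
DS_GRAPH_VERSION = 1                # generated into ds_graph_version.h from GRAPH_VERSION

def check_graph_version(graph_version):
    # Mirrors the check added to DS_CreateModel: models older than the
    # client's minimum are rejected before any inference setup happens.
    if graph_version < DS_GRAPH_VERSION:
        print('Specified model file version (%d) is incompatible with minimum '
              'version supported by this client (%d)' % (graph_version, DS_GRAPH_VERSION))
        return DS_ERR_MODEL_INCOMPATIBLE
    return 0  # DS_ERR_OK
```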