Embed more metadata in exported model and read it in native client
reuben committed Apr 5, 2019
1 parent 5745089 commit 7f6fd8b
Showing 5 changed files with 52 additions and 30 deletions.
15 changes: 13 additions & 2 deletions DeepSpeech.py
@@ -548,9 +548,9 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
     batch_size = batch_size if batch_size > 0 else None
 
     # Create feature computation graph
-    input_samples = tf.placeholder(tf.float32, [512], 'input_samples')
+    input_samples = tf.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
     samples = tf.expand_dims(input_samples, -1)
-    mfccs, _ = samples_to_mfccs(samples, 16000)
+    mfccs, _ = samples_to_mfccs(samples, FLAGS.audio_sample_rate)
     mfccs = tf.identity(mfccs, name='mfccs')
 
     # Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
@@ -724,6 +724,17 @@ def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklis
     if not FLAGS.export_tflite:
         frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
         frozen_graph.version = int(file_relative_read('GRAPH_VERSION').strip())
+
+        # Add a no-op node to the graph with metadata information to be loaded by the native client
+        metadata = frozen_graph.node.add()
+        metadata.name = 'model_metadata'
+        metadata.op = 'NoOp'
+        metadata.attr['sample_rate'].i = FLAGS.audio_sample_rate
+        metadata.attr['feature_win_len'].i = FLAGS.feature_win_len
+        metadata.attr['feature_win_step'].i = FLAGS.feature_win_step
+        if FLAGS.export_model_language:
+            metadata.attr['language'].s = FLAGS.export_model_language.encode('ascii')
+
         with open(output_graph_path, 'wb') as fout:
             fout.write(frozen_graph.SerializeToString())
     else:
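Since the metadata travels as attributes on an ordinary NoOp node, any GraphDef reader can recover it without running the graph. A minimal sketch of reading it back from Python (TF 1.x; the read_metadata helper and the output_graph.pb path are illustrative, not part of this commit):

    import tensorflow as tf

    def read_metadata(graph_path):
        # Parse the frozen GraphDef and look for the metadata node added above
        graph_def = tf.GraphDef()
        with open(graph_path, 'rb') as fin:
            graph_def.ParseFromString(fin.read())
        for node in graph_def.node:
            if node.name == 'model_metadata':
                return {
                    'sample_rate': node.attr['sample_rate'].i,
                    'feature_win_len': node.attr['feature_win_len'].i,
                    'feature_win_step': node.attr['feature_win_step'].i,
                    # 'language' is only set when --export_model_language was given
                    'language': node.attr['language'].s.decode('ascii')
                                if 'language' in node.attr else '',
                }
        return None  # graph was exported before this commit

    print(read_metadata('output_graph.pb'))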
51 changes: 24 additions & 27 deletions native_client/deepspeech.cc
@@ -39,25 +39,9 @@
 //TODO: infer batch size from model/use dynamic batch size
 constexpr unsigned int BATCH_SIZE = 1;
 
-//TODO: use dynamic sample rate
-constexpr unsigned int SAMPLE_RATE = 16000;
-
-constexpr float AUDIO_WIN_LEN = 0.032f;
-constexpr float AUDIO_WIN_STEP = 0.02f;
-constexpr unsigned int AUDIO_WIN_LEN_SAMPLES = (unsigned int)(AUDIO_WIN_LEN * SAMPLE_RATE);
-constexpr unsigned int AUDIO_WIN_STEP_SAMPLES = (unsigned int)(AUDIO_WIN_STEP * SAMPLE_RATE);
-
-constexpr size_t WINDOW_SIZE = AUDIO_WIN_LEN * SAMPLE_RATE;
-
-std::array<float, WINDOW_SIZE> calc_hamming_window() {
-  std::array<float, WINDOW_SIZE> a{0};
-  for (int i = 0; i < WINDOW_SIZE; ++i) {
-    a[i] = 0.54 - 0.46 * std::cos(2*M_PI*i/(WINDOW_SIZE-1));
-  }
-  return a;
-}
-
-std::array<float, WINDOW_SIZE> hamming_window = calc_hamming_window();
+constexpr unsigned int DEFAULT_SAMPLE_RATE = 16000;
+constexpr unsigned int DEFAULT_WINDOW_LENGTH = DEFAULT_SAMPLE_RATE * 0.032;
+constexpr unsigned int DEFAULT_WINDOW_STEP = DEFAULT_SAMPLE_RATE * 0.02;
 
 #ifndef USE_TFLITE
 using namespace tensorflow;
@@ -134,6 +118,9 @@ struct ModelState {
   unsigned int n_context;
   unsigned int n_features;
   unsigned int mfcc_feats_per_timestep;
+  unsigned int sample_rate;
+  unsigned int audio_win_len;
+  unsigned int audio_win_step;
 
 #ifdef USE_TFLITE
   size_t previous_state_size;
@@ -220,6 +207,9 @@ ModelState::ModelState()
   , n_context(-1)
   , n_features(-1)
   , mfcc_feats_per_timestep(-1)
+  , sample_rate(DEFAULT_SAMPLE_RATE)
+  , audio_win_len(DEFAULT_WINDOW_LENGTH)
+  , audio_win_step(DEFAULT_WINDOW_STEP)
 #ifdef USE_TFLITE
   , previous_state_size(0)
   , previous_state_c_(nullptr)
@@ -258,7 +248,7 @@ StreamingState::feedAudioContent(const short* buffer,
 {
   // Consume all the data that was passed in, processing full buffers if needed
   while (buffer_size > 0) {
-    while (buffer_size > 0 && audio_buffer.size() < AUDIO_WIN_LEN_SAMPLES) {
+    while (buffer_size > 0 && audio_buffer.size() < model->audio_win_len) {
       // Convert i16 sample into f32
       float multiplier = 1.0f / (1 << 15);
       audio_buffer.push_back((float)(*buffer) * multiplier);
@@ -267,10 +257,10 @@
     }
 
     // If the buffer is full, process and shift it
-    if (audio_buffer.size() == AUDIO_WIN_LEN_SAMPLES) {
+    if (audio_buffer.size() == model->audio_win_len) {
       processAudioWindow(audio_buffer);
       // Shift data by one step
-      shift_buffer_left(audio_buffer, AUDIO_WIN_STEP_SAMPLES);
+      shift_buffer_left(audio_buffer, model->audio_win_step);
     }
 
     // Repeat until buffer empty
@@ -461,13 +451,13 @@ void
 ModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
 {
 #ifndef USE_TFLITE
-  Tensor input(DT_FLOAT, TensorShape({AUDIO_WIN_LEN_SAMPLES}));
+  Tensor input(DT_FLOAT, TensorShape({audio_win_len}));
   auto input_mapped = input.flat<float>();
   int i;
   for (i = 0; i < samples.size(); ++i) {
     input_mapped(i) = samples[i];
   }
-  for (; i < AUDIO_WIN_LEN_SAMPLES; ++i) {
+  for (; i < audio_win_len; ++i) {
     input_mapped(i) = 0.f;
   }
 
@@ -556,8 +546,8 @@ ModelState::decode_metadata(const vector<float>& logits)
   for (int i = 0; i < out[0].tokens.size(); ++i) {
     metadata->items[i].character = (char*)alphabet->StringFromLabel(out[0].tokens[i]).c_str();
     metadata->items[i].timestep = out[0].timesteps[i];
-    metadata->items[i].start_time = static_cast<float>(out[0].timesteps[i] * AUDIO_WIN_STEP);
+    metadata->items[i].start_time = out[0].timesteps[i] * ((float)audio_win_step / sample_rate);
 
     if (metadata->items[i].start_time < 0) {
       metadata->items[i].start_time = 0;
     }
@@ -700,6 +690,13 @@ DS_CreateModel(const char* aModelPath,
                   << std::endl;
        return DS_ERR_INVALID_ALPHABET;
      }
+    } else if (node.name() == "model_metadata") {
+      int sample_rate = node.attr().at("sample_rate").i();
+      model->sample_rate = sample_rate;
+      int win_len_ms = node.attr().at("feature_win_len").i();
+      int win_step_ms = node.attr().at("feature_win_step").i();
+      model->audio_win_len = sample_rate * (win_len_ms / 1000.0);
+      model->audio_win_step = sample_rate * (win_step_ms / 1000.0);
     }
   }

@@ -833,7 +830,7 @@ DS_SetupStream(ModelState* aCtx,
 
   ctx->accumulated_logits.reserve(aPreAllocFrames * BATCH_SIZE * num_classes);
 
-  ctx->audio_buffer.reserve(AUDIO_WIN_LEN_SAMPLES);
+  ctx->audio_buffer.reserve(aCtx->audio_win_len);
   ctx->mfcc_buffer.reserve(aCtx->mfcc_feats_per_timestep);
   ctx->mfcc_buffer.resize(aCtx->n_features*aCtx->n_context, 0.f);
   ctx->batch_buffer.reserve(aCtx->n_steps * aCtx->mfcc_feats_per_timestep);
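The native client now recovers the window geometry from the embedded millisecond values with the same arithmetic as util/config.py, and decode_metadata derives seconds per timestep from it instead of the old hard-coded AUDIO_WIN_STEP. A sketch of that arithmetic with the 16 kHz defaults (helper names are illustrative, not part of this commit):

    def win_samples(sample_rate, win_ms):
        # milliseconds -> samples, e.g. 16000 * (32 / 1000.0) == 512
        return int(sample_rate * (win_ms / 1000.0))

    def timestep_to_seconds(timestep, sample_rate, win_step_ms):
        # each timestep advances the audio by one feature window step
        audio_win_step = win_samples(sample_rate, win_step_ms)
        return timestep * (float(audio_win_step) / sample_rate)

    assert win_samples(16000, 32) == 512    # old AUDIO_WIN_LEN_SAMPLES
    assert win_samples(16000, 20) == 320    # old AUDIO_WIN_STEP_SAMPLES
    assert timestep_to_seconds(50, 16000, 20) == 1.0   # 50 steps of 20 ms each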
6 changes: 6 additions & 0 deletions util/config.py
@@ -92,6 +92,12 @@ def initialize_globals():
     # Units in the sixth layer = number of characters in the target language plus one
     c.n_hidden_6 = c.alphabet.size() + 1 # +1 for CTC blank label
 
+    # Size of audio window in samples
+    c.audio_window_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_len / 1000)
+
+    # Stride for feature computations in samples
+    c.audio_step_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_step / 1000)
+
     if FLAGS.one_shot_infer:
         if not os.path.exists(FLAGS.one_shot_infer):
             log_error('Path specified in --one_shot_infer is not a valid file.')
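Deriving the window geometry from flags is what lets a non-16 kHz model be trained, exported, and consumed consistently. For example, under a hypothetical 8 kHz telephony configuration (illustrative values; assumes true division, i.e. Python 3 semantics):

    # Hypothetical flags: --audio_sample_rate 8000 --feature_win_len 32 --feature_win_step 20
    audio_sample_rate = 8000
    feature_win_len = 32     # ms
    feature_win_step = 20    # ms

    audio_window_samples = audio_sample_rate * (feature_win_len / 1000)   # 256.0
    audio_step_samples = audio_sample_rate * (feature_win_step / 1000)    # 160.0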
5 changes: 4 additions & 1 deletion util/feeding.py
@@ -27,7 +27,10 @@ def read_csvs(csv_files):
 
 
 def samples_to_mfccs(samples, sample_rate):
-    spectrogram = contrib_audio.audio_spectrogram(samples, window_size=512, stride=320, magnitude_squared=True)
+    spectrogram = contrib_audio.audio_spectrogram(samples,
+                                                  window_size=Config.audio_window_samples,
+                                                  stride=Config.audio_step_samples,
+                                                  magnitude_squared=True)
     mfccs = contrib_audio.mfcc(spectrogram, sample_rate, dct_coefficient_count=Config.n_input)
     mfccs = tf.reshape(mfccs, [-1, Config.n_input])
 
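audio_spectrogram frames the clip the same way the streaming client does: one frame per full window of window_size samples, advancing by stride samples. Assuming that framing, a clip of n samples should yield 1 + (n - window_size) // stride frames; a quick sanity check with the 16 kHz defaults (helper is illustrative, not part of this commit):

    def n_frames(n_samples, window_size=512, stride=320):
        # one frame per complete window; the window advances by `stride`
        if n_samples < window_size:
            return 0
        return 1 + (n_samples - window_size) // stride

    assert n_frames(512) == 1        # exactly one window
    assert n_frames(16000) == 49     # one second of 16 kHz audio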
5 changes: 5 additions & 0 deletions util/flags.py
@@ -19,6 +19,10 @@ def create_flags():
     tf.app.flags.DEFINE_string ('dev_cached_features_path', '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged')
     tf.app.flags.DEFINE_string ('test_cached_features_path', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged')
 
+    tf.app.flags.DEFINE_integer ('feature_win_len', 32, 'feature extraction audio window length in milliseconds')
+    tf.app.flags.DEFINE_integer ('feature_win_step', 20, 'feature extraction window step length in milliseconds')
+    tf.app.flags.DEFINE_integer ('audio_sample_rate', 16000, 'sample rate value expected by model')
+
     # Global Constants
     # ================

@@ -73,6 +77,7 @@ def create_flags():
     tf.app.flags.DEFINE_boolean ('export_tflite', False, 'export a graph ready for TF Lite engine')
     tf.app.flags.DEFINE_boolean ('use_seq_length', True, 'have sequence_length in the exported graph (will make tfcompile unhappy)')
     tf.app.flags.DEFINE_integer ('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency')
+    tf.app.flags.DEFINE_string ('export_model_language', '', 'language the model was trained on. Gets embedded into exported model.')
 
     # Reporting

