Skip to content

Commit

Permalink
Fix TFLite bug in feature computation graph and clean up deepspeech.c…
Browse files Browse the repository at this point in the history
…c a bit
  • Loading branch information
reuben committed Apr 3, 2019
1 parent a7cda8e commit 232df74
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 59 deletions.
7 changes: 2 additions & 5 deletions DeepSpeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,11 +548,10 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
batch_size = batch_size if batch_size > 0 else None

# Create feature computation graph
input_samples = tf.placeholder(tf.float32, [None], 'input_samples')
input_samples = tf.placeholder(tf.float32, [512], 'input_samples')
samples = tf.expand_dims(input_samples, -1)
mfccs, mfccs_len = samples_to_mfccs(samples, 16000)
mfccs, _ = samples_to_mfccs(samples, 16000)
mfccs = tf.identity(mfccs, name='mfccs')
mfccs_len = tf.identity(mfccs_len, name='mfccs_len')

# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
# This shape is read by the native_client in DS_CreateModel to know the
Expand Down Expand Up @@ -633,7 +632,6 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
'outputs': logits,
'initialize_state': initialize_state,
'mfccs': mfccs,
'mfccs_len': mfccs_len,
},
layers
)
Expand All @@ -659,7 +657,6 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
'new_state_c': new_state_c,
'new_state_h': new_state_h,
'mfccs': mfccs,
'mfccs_len': mfccs_len,
},
layers
)
Expand Down
118 changes: 64 additions & 54 deletions native_client/deepspeech.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ struct ModelState {
int new_state_c_idx;
int new_state_h_idx;
int mfccs_idx;
int mfccs_len_idx;
#endif

ModelState();
Expand All @@ -164,7 +163,7 @@ struct ModelState {
*
* @return String representing the decoded text.
*/
char* decode(vector<float>& logits);
char* decode(const vector<float>& logits);

/**
* @brief Perform decoding of the logits, using basic CTC decoder or
Expand All @@ -186,7 +185,7 @@ struct ModelState {
* @return Metadata struct containing MetadataItem structs for each character.
* The user is responsible for freeing Metadata by calling DS_FreeMetadata().
*/
Metadata* decode_metadata(vector<float>& logits);
Metadata* decode_metadata(const vector<float>& logits);

/**
* @brief Do a single inference step in the acoustic model, with:
Expand All @@ -203,9 +202,6 @@ struct ModelState {
void compute_mfcc(const vector<float>& audio_buffer, vector<float>& mfcc_output);
};

StreamingState* SetupStreamAndFeedAudioContent(ModelState* aCtx, const short* aBuffer,
unsigned int aBufferSize, unsigned int aSampleRate);

ModelState::ModelState()
:
#ifndef USE_TFLITE
Expand Down Expand Up @@ -465,22 +461,27 @@ void
ModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
{
#ifndef USE_TFLITE
Tensor input(DT_FLOAT, TensorShape({static_cast<long long>(samples.size())}));
Tensor input(DT_FLOAT, TensorShape({AUDIO_WIN_LEN_SAMPLES}));
auto input_mapped = input.flat<float>();
for (int i = 0; i < samples.size(); ++i) {
int i;
for (i = 0; i < samples.size(); ++i) {
input_mapped(i) = samples[i];
}
for (; i < AUDIO_WIN_LEN_SAMPLES; ++i) {
input_mapped(i) = 0.f;
}

vector<Tensor> outputs;
Status status = session->Run({{"input_samples", input}}, {"mfccs", "mfccs_len"}, {}, &outputs);
Status status = session->Run({{"input_samples", input}}, {"mfccs"}, {}, &outputs);

if (!status.ok()) {
std::cerr << "Error running session: " << status << "\n";
return;
}

auto mfcc_len_mapped = outputs[1].flat<int32>();
int n_windows = mfcc_len_mapped(0);
// The feature computation graph is hardcoded to one audio length for now
const int n_windows = 1;
assert(outputs[0].shape().num_elemements() / n_features == n_windows);

auto mfcc_mapped = outputs[0].flat<float>();
for (int i = 0; i < n_windows * n_features; ++i) {
Expand All @@ -499,7 +500,14 @@ ModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_outpu
return;
}

int n_windows = *interpreter->typed_tensor<float>(mfccs_len_idx);
// The feature computation graph is hardcoded to one audio length for now
int n_windows = 1;
TfLiteIntArray* out_dims = interpreter->tensor(mfccs_idx)->dims;
int num_elements = 1;
for (int i = 0; i < out_dims->size; ++i) {
num_elements *= out_dims->data[i];
}
assert(num_elements / n_features == n_windows);

float* outputs = interpreter->typed_tensor<float>(mfccs_idx);
for (int i = 0; i < n_windows * n_features; ++i) {
Expand All @@ -509,10 +517,9 @@ ModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_outpu
}

char*
ModelState::decode(vector<float>& logits)
ModelState::decode(const vector<float>& logits)
{
vector<Output> out = ModelState::decode_raw(logits);

return strdup(alphabet->LabelsToString(out[0].tokens).c_str());
}

Expand All @@ -535,7 +542,8 @@ ModelState::decode_raw(const vector<float>& logits)
return out;
}

Metadata* ModelState::decode_metadata(vector<float>& logits)
Metadata*
ModelState::decode_metadata(const vector<float>& logits)
{
vector<Output> out = decode_raw(logits);

Expand All @@ -559,7 +567,8 @@ Metadata* ModelState::decode_metadata(vector<float>& logits)
}

#ifdef USE_TFLITE
int tflite_get_tensor_by_name(const ModelState* ctx, const vector<int>& list, const char* name)
int
tflite_get_tensor_by_name(const ModelState* ctx, const vector<int>& list, const char* name)
{
int rv = -1;

Expand All @@ -574,12 +583,14 @@ int tflite_get_tensor_by_name(const ModelState* ctx, const vector<int>& list, co
return rv;
}

int tflite_get_input_tensor_by_name(const ModelState* ctx, const char* name)
int
tflite_get_input_tensor_by_name(const ModelState* ctx, const char* name)
{
return ctx->interpreter->inputs()[tflite_get_tensor_by_name(ctx, ctx->interpreter->inputs(), name)];
}

int tflite_get_output_tensor_by_name(const ModelState* ctx, const char* name)
int
tflite_get_output_tensor_by_name(const ModelState* ctx, const char* name)
{
return ctx->interpreter->outputs()[tflite_get_tensor_by_name(ctx, ctx->interpreter->outputs(), name)];
}
Expand Down Expand Up @@ -729,7 +740,6 @@ DS_CreateModel(const char* aModelPath,
model->new_state_c_idx = tflite_get_output_tensor_by_name(model.get(), "new_state_c");
model->new_state_h_idx = tflite_get_output_tensor_by_name(model.get(), "new_state_h");
model->mfccs_idx = tflite_get_output_tensor_by_name(model.get(), "mfccs");
model->mfccs_len_idx = tflite_get_output_tensor_by_name(model.get(), "mfccs_len");

TfLiteIntArray* dims_input_node = model->interpreter->tensor(model->input_node_idx)->dims;

Expand Down Expand Up @@ -792,41 +802,6 @@ DS_EnableDecoderWithLM(ModelState* aCtx,
}
}

char*
DS_SpeechToText(ModelState* aCtx,
const short* aBuffer,
unsigned int aBufferSize,
unsigned int aSampleRate)
{
StreamingState* ctx = SetupStreamAndFeedAudioContent(aCtx, aBuffer, aBufferSize, aSampleRate);
return DS_FinishStream(ctx);
}

Metadata*
DS_SpeechToTextWithMetadata(ModelState* aCtx,
const short* aBuffer,
unsigned int aBufferSize,
unsigned int aSampleRate)
{
StreamingState* ctx = SetupStreamAndFeedAudioContent(aCtx, aBuffer, aBufferSize, aSampleRate);
return DS_FinishStreamWithMetadata(ctx);
}

StreamingState*
SetupStreamAndFeedAudioContent(ModelState* aCtx,
const short* aBuffer,
unsigned int aBufferSize,
unsigned int aSampleRate)
{
StreamingState* ctx;
int status = DS_SetupStream(aCtx, 0, aSampleRate, &ctx);
if (status != DS_ERR_OK) {
return nullptr;
}
DS_FeedAudioContent(ctx, aBuffer, aBufferSize);
return ctx;
}

int
DS_SetupStream(ModelState* aCtx,
unsigned int aPreAllocFrames,
Expand Down Expand Up @@ -899,6 +874,41 @@ DS_FinishStreamWithMetadata(StreamingState* aSctx)
return metadata;
}

StreamingState*
SetupStreamAndFeedAudioContent(ModelState* aCtx,
const short* aBuffer,
unsigned int aBufferSize,
unsigned int aSampleRate)
{
StreamingState* ctx;
int status = DS_SetupStream(aCtx, 0, aSampleRate, &ctx);
if (status != DS_ERR_OK) {
return nullptr;
}
DS_FeedAudioContent(ctx, aBuffer, aBufferSize);
return ctx;
}

char*
DS_SpeechToText(ModelState* aCtx,
const short* aBuffer,
unsigned int aBufferSize,
unsigned int aSampleRate)
{
StreamingState* ctx = SetupStreamAndFeedAudioContent(aCtx, aBuffer, aBufferSize, aSampleRate);
return DS_FinishStream(ctx);
}

Metadata*
DS_SpeechToTextWithMetadata(ModelState* aCtx,
const short* aBuffer,
unsigned int aBufferSize,
unsigned int aSampleRate)
{
StreamingState* ctx = SetupStreamAndFeedAudioContent(aCtx, aBuffer, aBufferSize, aSampleRate);
return DS_FinishStreamWithMetadata(ctx);
}

void
DS_DiscardStream(StreamingState* aSctx)
{
Expand Down
11 changes: 11 additions & 0 deletions tc-tests-utils.sh
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,11 @@ assert_shows_something()
fi;

case "${stderr}" in
*"incompatible with minimum version"*)
echo "Prod model too old for client, skipping test."
return 0
;;

*${expected}*)
echo "Proper output has been produced:"
echo "${stderr}"
Expand Down Expand Up @@ -342,10 +347,14 @@ run_all_inference_tests()
set -e
assert_correct_ldc93s1_lm "${phrase_pbmodel_withlm_stereo_44k}" "$status"

set +e
phrase_pbmodel_nolm_mono_8k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_8000.wav 2>&1 1>/dev/null)
set -e
assert_correct_warning_upsampling "${phrase_pbmodel_nolm_mono_8k}"

set +e
phrase_pbmodel_withlm_mono_8k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_8000.wav 2>&1 1>/dev/null)
set -e
assert_correct_warning_upsampling "${phrase_pbmodel_withlm_mono_8k}"
}

Expand All @@ -369,7 +378,9 @@ run_prod_inference_tests()
set -e
assert_correct_ldc93s1_prodmodel_stereo_44k "${phrase_pbmodel_withlm_stereo_44k}" "$status"

set +e
phrase_pbmodel_withlm_mono_8k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_8000.wav 2>&1 1>/dev/null)
set -e
assert_correct_warning_upsampling "${phrase_pbmodel_withlm_mono_8k}"
}

Expand Down

0 comments on commit 232df74

Please sign in to comment.