Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CTC streaming decoder #2121

Merged
merged 1 commit into from
May 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
CTC beam search streaming decoder (+6 squashed commits)
Squashed commits:
[2941b47] Fixed nits
[700572e] Restored old CTC decoder API
[5aaf75d] Fixed nits
[969d71a] Added a destructor for DecoderState
[af0be6e] Removed accumulated_logits
[9dcb7b4] CTC beam search streaming decoder
  • Loading branch information
dabinat committed May 22, 2019
commit d9a269412e492ca5046c9ab89f0547be51050159
132 changes: 86 additions & 46 deletions native_client/ctcdecode/ctc_beam_search_decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,49 +14,61 @@

using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;

std::vector<Output> ctc_beam_search_decoder(
const double *probs,
int time_dim,
int class_dim,
const Alphabet &alphabet,
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer) {
DecoderState* decoder_init(const Alphabet &alphabet,
int class_dim,
Scorer* ext_scorer) {

// dimension check
VALID_CHECK_EQ(class_dim, alphabet.GetSize()+1,
"The shape of probs does not match with "
"the shape of the vocabulary");

// assign special ids
int space_id = alphabet.GetSpaceLabel();
int blank_id = alphabet.GetSize();
DecoderState *state = new DecoderState;
state->space_id = alphabet.GetSpaceLabel();
state->blank_id = alphabet.GetSize();

// init prefixes' root
PathTrie root;
root.score = root.log_prob_b_prev = 0.0;
std::vector<PathTrie *> prefixes;
prefixes.push_back(&root);
PathTrie *root = new PathTrie;
root->score = root->log_prob_b_prev = 0.0;

state->prefix_root = root;

state->prefixes.push_back(root);

if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
auto dict_ptr = ext_scorer->dictionary->Copy(true);
root.set_dictionary(dict_ptr);
root->set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
root.set_matcher(matcher);
root->set_matcher(matcher);
}

return state;
}

void decoder_next(const double *probs,
const Alphabet &alphabet,
DecoderState *state,
int time_dim,
int class_dim,
double cutoff_prob,
size_t cutoff_top_n,
size_t beam_size,
Scorer *ext_scorer) {

// prefix search over time
// prefix search over time
for (size_t time_step = 0; time_step < time_dim; ++time_step) {
auto *prob = &probs[time_step*class_dim];

float min_cutoff = -NUM_FLT_INF;
bool full_beam = false;
if (ext_scorer != nullptr) {
size_t num_prefixes = std::min(prefixes.size(), beam_size);
size_t num_prefixes = std::min(state->prefixes.size(), beam_size);
std::sort(
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
min_cutoff = prefixes[num_prefixes - 1]->score +
std::log(prob[blank_id]) - std::max(0.0, ext_scorer->beta);
state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare);

min_cutoff = state->prefixes[num_prefixes - 1]->score +
std::log(prob[state->blank_id]) - std::max(0.0, ext_scorer->beta);
full_beam = (num_prefixes == beam_size);
}

Expand All @@ -67,22 +79,25 @@ std::vector<Output> ctc_beam_search_decoder(
auto c = log_prob_idx[index].first;
auto log_prob_c = log_prob_idx[index].second;

for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) {
auto prefix = prefixes[i];
for (size_t i = 0; i < state->prefixes.size() && i < beam_size; ++i) {
auto prefix = state->prefixes[i];
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break;
}

// blank
if (c == blank_id) {
if (c == state->blank_id) {
prefix->log_prob_b_cur =
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
continue;
}

// repeated character
if (c == prefix->character) {
prefix->log_prob_nb_cur = log_sum_exp(
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
}

// get new prefix
auto prefix_new = prefix->get_path_trie(c, time_step, log_prob_c);

Expand All @@ -98,7 +113,7 @@ std::vector<Output> ctc_beam_search_decoder(

// language model scoring
if (ext_scorer != nullptr &&
(c == space_id || ext_scorer->is_character_based())) {
(c == state->space_id || ext_scorer->is_character_based())) {
PathTrie *prefix_to_score = nullptr;
// skip scoring the space
if (ext_scorer->is_character_based()) {
Expand All @@ -114,34 +129,41 @@ std::vector<Output> ctc_beam_search_decoder(
log_p += score;
log_p += ext_scorer->beta;
}

prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
}
} // end of loop over prefix
} // end of loop over vocabulary


prefixes.clear();

// update log probs
root.iterate_to_vec(prefixes);
state->prefixes.clear();
state->prefix_root->iterate_to_vec(state->prefixes);

// only preserve top beam_size prefixes
if (prefixes.size() >= beam_size) {
std::nth_element(prefixes.begin(),
prefixes.begin() + beam_size,
prefixes.end(),
if (state->prefixes.size() >= beam_size) {
std::nth_element(state->prefixes.begin(),
state->prefixes.begin() + beam_size,
state->prefixes.end(),
prefix_compare);
for (size_t i = beam_size; i < prefixes.size(); ++i) {
prefixes[i]->remove();
for (size_t i = beam_size; i < state->prefixes.size(); ++i) {
state->prefixes[i]->remove();
}
}

} // end of loop over time
}

std::vector<Output> decoder_decode(DecoderState *state,
const Alphabet &alphabet,
size_t beam_size,
Scorer* ext_scorer) {

// score the last word of each prefix that doesn't end with space
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
auto prefix = prefixes[i];
if (!prefix->is_empty() && prefix->character != space_id) {
for (size_t i = 0; i < beam_size && i < state->prefixes.size(); ++i) {
auto prefix = state->prefixes[i];
if (!prefix->is_empty() && prefix->character != state->space_id) {
float score = 0.0;
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
Expand All @@ -151,30 +173,48 @@ std::vector<Output> ctc_beam_search_decoder(
}
}

size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
size_t num_prefixes = std::min(state->prefixes.size(), beam_size);
std::sort(state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare);

// compute aproximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
double approx_ctc = prefixes[i]->score;
for (size_t i = 0; i < beam_size && i < state->prefixes.size(); ++i) {
double approx_ctc = state->prefixes[i]->score;
if (ext_scorer != nullptr) {
std::vector<int> output;
std::vector<int> timesteps;
prefixes[i]->get_path_vec(output, timesteps);
state->prefixes[i]->get_path_vec(output, timesteps);
auto prefix_length = output.size();
auto words = ext_scorer->split_labels(output);
// remove word insert
approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
// remove language model weight:
approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
}
prefixes[i]->approx_ctc = approx_ctc;
state->prefixes[i]->approx_ctc = approx_ctc;
}

return get_beam_search_result(prefixes, beam_size);
return get_beam_search_result(state->prefixes, beam_size);
}

std::vector<Output> ctc_beam_search_decoder(
const double *probs,
int time_dim,
int class_dim,
const Alphabet &alphabet,
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer) {

DecoderState *state = decoder_init(alphabet, class_dim, ext_scorer);
decoder_next(probs, alphabet, state, time_dim, class_dim, cutoff_prob, cutoff_top_n, beam_size, ext_scorer);
std::vector<Output> out = decoder_decode(state, alphabet, beam_size, ext_scorer);

delete state;

return out;
}

std::vector<std::vector<Output>>
ctc_beam_search_decoder_batch(
Expand Down
76 changes: 68 additions & 8 deletions native_client/ctcdecode/ctc_beam_search_decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,73 @@
#include "scorer.h"
#include "output.h"
#include "alphabet.h"
#include "decoderstate.h"

/* CTC Beam Search Decoder
/* Initialize CTC beam search decoder

* Parameters:
* alphabet: The alphabet.
* class_dim: Alphabet length (plus 1 for space character).
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* Return:
* A struct containing prefixes and state variables.
*/
DecoderState* decoder_init(const Alphabet &alphabet,
int class_dim,
Scorer *ext_scorer);

/* Send data to the decoder

* Parameters:
* probs: 2-D vector where each element is a vector of probabilities
* over alphabet of one time step.
* alphabet: The alphabet.
* state: The state structure previously obtained from decoder_init().
* time_dim: Number of timesteps.
* class_dim: Alphabet length (plus 1 for space character).
* cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning.
* beam_size: The width of beam search.
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
*/
void decoder_next(const double *probs,
const Alphabet &alphabet,
DecoderState *state,
int time_dim,
int class_dim,
double cutoff_prob,
size_t cutoff_top_n,
size_t beam_size,
Scorer *ext_scorer);

/* Get transcription for the data you sent via decoder_next()

* Parameters:
* state: The state structure previously obtained from decoder_init().
* alphabet: The alphabet.
* beam_size: The width of beam search.
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* Return:
* A vector where each element is a pair of score and decoding result,
* in descending order.
*/
std::vector<Output> decoder_decode(DecoderState *state,
const Alphabet &alphabet,
size_t beam_size,
Scorer* ext_scorer);

/* CTC Beam Search Decoder
* Parameters:
* probs_seq: 2-D vector that each element is a vector of probabilities
* over alphabet of one time step.
* probs: 2-D vector where each element is a vector of probabilities
* over alphabet of one time step.
* time_dim: Number of timesteps.
* class_dim: Alphabet length (plus 1 for space character).
* alphabet: The alphabet.
* beam_size: The width of beam search.
* cutoff_prob: Cutoff probability for pruning.
Expand All @@ -21,8 +82,8 @@
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* Return:
* A vector that each element is a pair of score and decoding result,
* in desending order.
* A vector where each element is a pair of score and decoding result,
* in descending order.
*/

std::vector<Output> ctc_beam_search_decoder(
Expand All @@ -36,9 +97,8 @@ std::vector<Output> ctc_beam_search_decoder(
Scorer *ext_scorer);

/* CTC Beam Search Decoder for batch data

* Parameters:
* probs_seq: 3-D vector that each element is a 2-D vector that can be used
* probs: 3-D vector where each element is a 2-D vector that can be used
* by ctc_beam_search_decoder().
* alphabet: The alphabet.
* beam_size: The width of beam search.
Expand All @@ -49,7 +109,7 @@ std::vector<Output> ctc_beam_search_decoder(
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* Return:
* A 2-D vector that each element is a vector of beam search decoding
* A 2-D vector where each element is a vector of beam search decoding
* result for one audio sample.
*/
std::vector<std::vector<Output>>
Expand Down
22 changes: 22 additions & 0 deletions native_client/ctcdecode/decoderstate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef DECODERSTATE_H_
#define DECODERSTATE_H_

#include <vector>

/* Struct for the state of the decoder, containing the prefixes and initial root prefix plus state variables. */

struct DecoderState {
int space_id;
int blank_id;
std::vector<PathTrie*> prefixes;
PathTrie *prefix_root;

~DecoderState() {
if (prefix_root != nullptr) {
delete prefix_root;
}
prefix_root = nullptr;
}
};

#endif // DECODERSTATE_H_
Loading