CTC beam search streaming decoder
Backported from PR mozilla#2121

Signed-off-by: Li Li <eggonlea@msn.com>
eggonlea committed May 20, 2019
1 parent 9fe9827 commit f70f737
Showing 4 changed files with 152 additions and 144 deletions.
152 changes: 66 additions & 86 deletions native_client/ctcdecode/ctc_beam_search_decoder.cpp
@@ -14,49 +14,61 @@

using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;

std::vector<Output> ctc_beam_search_decoder(
const double *probs,
int time_dim,
int class_dim,
const Alphabet &alphabet,
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer) {
DecoderState* decoder_init(const Alphabet &alphabet,
int class_dim,
Scorer* ext_scorer) {

// dimension check
VALID_CHECK_EQ(class_dim, alphabet.GetSize()+1,
"The shape of probs does not match with "
"the shape of the vocabulary");

// assign special ids
int space_id = alphabet.GetSpaceLabel();
int blank_id = alphabet.GetSize();
DecoderState *state = new DecoderState;
state->space_id = alphabet.GetSpaceLabel();
state->blank_id = alphabet.GetSize();

// init prefixes' root
PathTrie root;
root.score = root.log_prob_b_prev = 0.0;
std::vector<PathTrie *> prefixes;
prefixes.push_back(&root);
PathTrie *root = new PathTrie;
root->score = root->log_prob_b_prev = 0.0;

state->prefix_root = root;

state->prefixes.push_back(root);

if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
auto dict_ptr = ext_scorer->dictionary->Copy(true);
root.set_dictionary(dict_ptr);
root->set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
root.set_matcher(matcher);
root->set_matcher(matcher);
}

return state;
}

void decoder_next(const double *probs,
const Alphabet &alphabet,
DecoderState *state,
int time_dim,
int class_dim,
double cutoff_prob,
size_t cutoff_top_n,
size_t beam_size,
Scorer *ext_scorer) {

// prefix search over time
for (size_t time_step = 0; time_step < time_dim; ++time_step) {
auto *prob = &probs[time_step*class_dim];

float min_cutoff = -NUM_FLT_INF;
bool full_beam = false;
if (ext_scorer != nullptr) {
size_t num_prefixes = std::min(prefixes.size(), beam_size);
size_t num_prefixes = std::min(state->prefixes.size(), beam_size);
std::sort(
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
min_cutoff = prefixes[num_prefixes - 1]->score +
std::log(prob[blank_id]) - std::max(0.0, ext_scorer->beta);
state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare);

min_cutoff = state->prefixes[num_prefixes - 1]->score +
std::log(prob[state->blank_id]) - std::max(0.0, ext_scorer->beta);
full_beam = (num_prefixes == beam_size);
}

@@ -67,22 +79,25 @@ std::vector<Output> ctc_beam_search_decoder(
auto c = log_prob_idx[index].first;
auto log_prob_c = log_prob_idx[index].second;

for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) {
auto prefix = prefixes[i];
for (size_t i = 0; i < state->prefixes.size() && i < beam_size; ++i) {
auto prefix = state->prefixes[i];
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break;
}

// blank
if (c == blank_id) {
if (c == state->blank_id) {
prefix->log_prob_b_cur =
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
continue;
}

// repeated character
if (c == prefix->character) {
prefix->log_prob_nb_cur = log_sum_exp(
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
}

// get new prefix
auto prefix_new = prefix->get_path_trie(c, time_step);

@@ -98,7 +113,7 @@ std::vector<Output> ctc_beam_search_decoder(

// language model scoring
if (ext_scorer != nullptr &&
(c == space_id || ext_scorer->is_character_based())) {
(c == state->space_id || ext_scorer->is_character_based())) {
PathTrie *prefix_to_score = nullptr;
// skip scoring the space
if (ext_scorer->is_character_based()) {
@@ -114,34 +129,41 @@ std::vector<Output> ctc_beam_search_decoder(
log_p += score;
log_p += ext_scorer->beta;
}

prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
}
} // end of loop over prefix
} // end of loop over vocabulary


prefixes.clear();
// update log probs
root.iterate_to_vec(prefixes);
state->prefixes.clear();
state->prefix_root->iterate_to_vec(state->prefixes);

// only preserve top beam_size prefixes
if (prefixes.size() >= beam_size) {
std::nth_element(prefixes.begin(),
prefixes.begin() + beam_size,
prefixes.end(),
if (state->prefixes.size() >= beam_size) {
std::nth_element(state->prefixes.begin(),
state->prefixes.begin() + beam_size,
state->prefixes.end(),
prefix_compare);
for (size_t i = beam_size; i < prefixes.size(); ++i) {
prefixes[i]->remove();
for (size_t i = beam_size; i < state->prefixes.size(); ++i) {
state->prefixes[i]->remove();
}
}

} // end of loop over time
}

std::vector<Output> decoder_decode(DecoderState *state,
const Alphabet &alphabet,
size_t beam_size,
Scorer* ext_scorer) {

// score the last word of each prefix that doesn't end with space
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
auto prefix = prefixes[i];
if (!prefix->is_empty() && prefix->character != space_id) {
for (size_t i = 0; i < beam_size && i < state->prefixes.size(); ++i) {
auto prefix = state->prefixes[i];
if (!prefix->is_empty() && prefix->character != state->space_id) {
float score = 0.0;
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
@@ -151,68 +173,26 @@ std::vector<Output> ctc_beam_search_decoder(
}
}

size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
size_t num_prefixes = std::min(state->prefixes.size(), beam_size);
std::sort(state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare);

// compute approximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
double approx_ctc = prefixes[i]->score;
for (size_t i = 0; i < beam_size && i < state->prefixes.size(); ++i) {
double approx_ctc = state->prefixes[i]->score;
if (ext_scorer != nullptr) {
std::vector<int> output;
std::vector<int> timesteps;
prefixes[i]->get_path_vec(output, timesteps);
state->prefixes[i]->get_path_vec(output, timesteps);
auto prefix_length = output.size();
auto words = ext_scorer->split_labels(output);
// remove word insert
approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
// remove language model weight:
approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
}
prefixes[i]->approx_ctc = approx_ctc;
}

return get_beam_search_result(prefixes, beam_size);
}


std::vector<std::vector<Output>>
ctc_beam_search_decoder_batch(
const double *probs,
int batch_size,
int time_dim,
int class_dim,
const int* seq_lengths,
int seq_lengths_size,
const Alphabet &alphabet,
size_t beam_size,
size_t num_processes,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer) {
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
VALID_CHECK_EQ(batch_size, seq_lengths_size, "must have one sequence length per batch element");
// thread pool
ThreadPool pool(num_processes);

// enqueue the tasks of decoding
std::vector<std::future<std::vector<Output>>> res;
for (size_t i = 0; i < batch_size; ++i) {
res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
&probs[i*time_dim*class_dim],
seq_lengths[i],
class_dim,
alphabet,
beam_size,
cutoff_prob,
cutoff_top_n,
ext_scorer));
state->prefixes[i]->approx_ctc = approx_ctc;
}

// get decoding results
std::vector<std::vector<Output>> batch_results;
for (size_t i = 0; i < batch_size; ++i) {
batch_results.emplace_back(res[i].get());
}
return batch_results;
return get_beam_search_result(state->prefixes, beam_size);
}
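
Throughout decoder_next, path probabilities are accumulated in the log domain with log_sum_exp, a helper that lives in the decoder utilities rather than in this diff. For reference, it combines two log-probabilities as log(exp(a) + exp(b)); a minimal stand-alone sketch of that idea (not the project's actual implementation) is:

#include <algorithm>
#include <cmath>
#include <limits>

// Numerically stable log(exp(a) + exp(b)): factor out the larger argument so the
// exponentials stay in [0, 1] and never overflow. Negative infinity acts as the
// identity element, so an "empty" log-probability can be combined safely.
inline double log_sum_exp_sketch(double a, double b) {
  const double neg_inf = -std::numeric_limits<double>::infinity();
  if (a == neg_inf) return b;
  if (b == neg_inf) return a;
  const double m = std::max(a, b);
  return m + std::log(std::exp(a - m) + std::exp(b - m));
}
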
76 changes: 39 additions & 37 deletions native_client/ctcdecode/ctc_beam_search_decoder.h
@@ -7,64 +7,66 @@
#include "scorer.h"
#include "output.h"
#include "alphabet.h"
#include "decoderstate.h"

/* CTC Beam Search Decoder
/* Initialize CTC beam search decoder
* Parameters:
* probs_seq: 2-D vector that each element is a vector of probabilities
* alphabet: The alphabet.
* class_dim: Alphabet length (plus 1 for the CTC blank label).
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* Return:
* A struct containing word prefixes and state variables.
*/
DecoderState* decoder_init(const Alphabet &alphabet,
int class_dim,
Scorer *ext_scorer);

/* Send data to the decoder
* Parameters:
* probs: 2-D vector where each element is a vector of probabilities
* over alphabet of one time step.
* alphabet: The alphabet.
* beam_size: The width of beam search.
* state: The state structure previously obtained from decoder_init().
* time_dim: Number of timesteps.
* class_dim: Alphabet length (plus 1 for the CTC blank label).
* cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning.
* beam_size: The width of beam search.
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* Return:
* A vector where each element is a pair of score and decoding result,
* in descending order.
* Nothing; the state structure passed in is updated in place.
*/
void decoder_next(const double *probs,
const Alphabet &alphabet,
DecoderState *state,
int time_dim,
int class_dim,
double cutoff_prob,
size_t cutoff_top_n,
size_t beam_size,
Scorer *ext_scorer);

std::vector<Output> ctc_beam_search_decoder(
const double* probs,
int time_dim,
int class_dim,
const Alphabet &alphabet,
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer);

/* CTC Beam Search Decoder for batch data
/* Get transcription for the data you sent via decoder_next()
* Parameters:
* probs_seq: 3-D vector that each element is a 2-D vector that can be used
* by ctc_beam_search_decoder().
* state: The state structure previously obtained from decoder_init().
* alphabet: The alphabet.
* beam_size: The width of beam search.
* num_processes: Number of threads for beam search.
* cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning.
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* Return:
* A 2-D vector that each element is a vector of beam search decoding
* result for one audio sample.
* A vector where each element is a pair of score and decoding result, in descending order.
*/
std::vector<std::vector<Output>>
ctc_beam_search_decoder_batch(
const double* probs,
int batch_size,
int time_dim,
int class_dim,
const int* seq_lengths,
int seq_lengths_size,
const Alphabet &alphabet,
size_t beam_size,
size_t num_processes,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer);
std::vector<Output> decoder_decode(DecoderState *state,
const Alphabet &alphabet,
size_t beam_size,
Scorer* ext_scorer);

#endif // CTC_BEAM_SEARCH_DECODER_H_
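
To make the new control flow concrete, here is a minimal usage sketch of the three entry points declared above. The decode_in_chunks wrapper name, the chunk layout, the pruning parameter values, and the cleanup at the end are illustrative assumptions; this commit does not show how callers manage the DecoderState lifetime.

#include <vector>
#include "ctc_beam_search_decoder.h"

std::vector<Output> decode_in_chunks(const std::vector<std::vector<double>> &chunks,
                                     const Alphabet &alphabet,
                                     Scorer *ext_scorer) {
  const int class_dim = alphabet.GetSize() + 1;  // alphabet labels plus the CTC blank
  const size_t beam_size = 500;                  // illustrative pruning parameters
  const double cutoff_prob = 1.0;
  const size_t cutoff_top_n = 40;

  DecoderState *state = decoder_init(alphabet, class_dim, ext_scorer);

  // Feed each chunk of acoustic model output as it becomes available.
  // Each chunk is assumed to be a flat, row-major [time_dim x class_dim] buffer.
  for (const std::vector<double> &chunk : chunks) {
    const int time_dim = static_cast<int>(chunk.size()) / class_dim;
    decoder_next(chunk.data(), alphabet, state, time_dim, class_dim,
                 cutoff_prob, cutoff_top_n, beam_size, ext_scorer);
  }

  // Apply the final language-model rescoring and read out the n-best results.
  std::vector<Output> results = decoder_decode(state, alphabet, beam_size, ext_scorer);

  delete state;  // assumption: the caller owns the state; prefix trie teardown is not shown in this diff
  return results;
}

A one-shot decode is simply the special case of a single chunk covering all timesteps.
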
16 changes: 16 additions & 0 deletions native_client/ctcdecode/decoderstate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#ifndef DECODERSTATE_H_
#define DECODERSTATE_H_

#include <vector>

/* Struct holding the decoder state between calls: the blank and space label ids, the prefix trie root,
* and the current set of candidate prefixes kept by the beam search
*/
struct DecoderState {
int space_id;
int blank_id;
std::vector<PathTrie*> prefixes;
PathTrie *prefix_root;
};

#endif // DECODERSTATE_H_