Skip to content

Commit

Permalink
Restored old CTC decoder API
Browse files Browse the repository at this point in the history
  • Loading branch information
dabinat committed May 21, 2019
1 parent 5aaf75d commit 700572e
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 0 deletions.
59 changes: 59 additions & 0 deletions native_client/ctcdecode/ctc_beam_search_decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,62 @@ std::vector<Output> decoder_decode(DecoderState *state,

return get_beam_search_result(state->prefixes, beam_size);
}


std::vector<Output> ctc_beam_search_decoder(
const double *probs,
int time_dim,
int class_dim,
const Alphabet &alphabet,
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer) {

DecoderState *state = decoder_init(alphabet, class_dim, ext_scorer);

decoder_next(probs, alphabet, state, time_dim, class_dim, cutoff_prob, cutoff_top_n, beam_size, ext_scorer);

return decoder_decode(state, alphabet, beam_size, ext_scorer);
}

std::vector<std::vector<Output>>
ctc_beam_search_decoder_batch(
const double *probs,
int batch_size,
int time_dim,
int class_dim,
const int* seq_lengths,
int seq_lengths_size,
const Alphabet &alphabet,
size_t beam_size,
size_t num_processes,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer) {
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
VALID_CHECK_EQ(batch_size, seq_lengths_size, "must have one sequence length per batch element");
// thread pool
ThreadPool pool(num_processes);

// enqueue the tasks of decoding
std::vector<std::future<std::vector<Output>>> res;
for (size_t i = 0; i < batch_size; ++i) {
res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
&probs[i*time_dim*class_dim],
seq_lengths[i],
class_dim,
alphabet,
beam_size,
cutoff_prob,
cutoff_top_n,
ext_scorer));
}

// get decoding results
std::vector<std::vector<Output>> batch_results;
for (size_t i = 0; i < batch_size; ++i) {
batch_results.emplace_back(res[i].get());
}
return batch_results;
}
59 changes: 59 additions & 0 deletions native_client/ctcdecode/ctc_beam_search_decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,63 @@ std::vector<Output> decoder_decode(DecoderState *state,
size_t beam_size,
Scorer* ext_scorer);

/* CTC Beam Search Decoder
 *
 * Decodes a single input in one call.
 *
 * Parameters:
 *     probs: Pointer to a flat array of probabilities holding, for each
 *            time step, a vector of probabilities over the alphabet
 *            (presumably time_dim x class_dim, row-major -- TODO confirm
 *            against decoder_next()).
 *     time_dim: Number of timesteps.
 *     class_dim: Alphabet length plus 1.
 *                NOTE(review): the original comment attributed the extra
 *                class to the space character; it is presumably the CTC
 *                blank label instead -- confirm against decoder_init().
 *     alphabet: The alphabet.
 *     beam_size: The width of beam search.
 *     cutoff_prob: Cutoff probability for pruning.
 *     cutoff_top_n: Cutoff number for pruning.
 *     ext_scorer: External scorer to evaluate a prefix, which consists of
 *                 n-gram language model scoring and word insertion term.
 *                 Pass null to decode the input sample without scorer.
 *                 The declaration has no default argument, so the
 *                 parameter must always be supplied explicitly.
 * Return:
 *     A vector of Output where each element is a pair of score and
 *     decoding result, in descending order.
*/

std::vector<Output> ctc_beam_search_decoder(
const double* probs,
int time_dim,
int class_dim,
const Alphabet &alphabet,
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer);

/* CTC Beam Search Decoder for batch data
 *
 * Decodes every item of a batch by running ctc_beam_search_decoder() on a
 * thread pool.
 *
 * Parameters:
 *     probs: Pointer to a flat array holding the batch items back to back;
 *            item i starts at offset i * time_dim * class_dim and can be
 *            used by ctc_beam_search_decoder().
 *     batch_size: Number of items in the batch.
 *     time_dim: Number of timesteps allotted to each item in probs.
 *     class_dim: Number of classes per timestep, as expected by
 *                ctc_beam_search_decoder().
 *     seq_lengths: Array with one entry per batch item; entry i is the
 *                  number of timesteps decoded for item i.
 *     seq_lengths_size: Number of entries in seq_lengths; must equal
 *                       batch_size.
 *     alphabet: The alphabet.
 *     beam_size: The width of beam search.
 *     num_processes: Number of threads for beam search; must be greater
 *                    than zero.
 *     cutoff_prob: Cutoff probability for pruning.
 *     cutoff_top_n: Cutoff number for pruning.
 *     ext_scorer: External scorer to evaluate a prefix, which consists of
 *                 n-gram language model scoring and word insertion term.
 *                 Pass null to decode the input samples without scorer.
 *                 The declaration has no default argument, so the
 *                 parameter must always be supplied explicitly.
 * Return:
 *     A 2-D vector where each element is a vector of beam search decoding
 *     result for one audio sample, in batch order.
*/
std::vector<std::vector<Output>>
ctc_beam_search_decoder_batch(
const double* probs,
int batch_size,
int time_dim,
int class_dim,
const int* seq_lengths,
int seq_lengths_size,
const Alphabet &alphabet,
size_t beam_size,
size_t num_processes,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer);

#endif // CTC_BEAM_SEARCH_DECODER_H_

0 comments on commit 700572e

Please sign in to comment.