From 700572e4087aafe5f86aac34ab00260c05a3cc9c Mon Sep 17 00:00:00 2001
From: dabinat
Date: Mon, 20 May 2019 23:16:18 -0700
Subject: [PATCH] Restored old CTC decoder API

---
NOTE(review): the copy of this patch that was reviewed had been collapsed onto
three lines and had every <...> span stripped. Line breaks and indentation are
reconstructed, and the template arguments are presumed to be Output -- confirm
both against the repository before applying. The blob ids on the `index` lines
still describe the original, unrevised post-image.

 .../ctcdecode/ctc_beam_search_decoder.cpp     | 67 +++++++++++++++++++
 .../ctcdecode/ctc_beam_search_decoder.h       | 65 +++++++++++++++++++
 2 files changed, 132 insertions(+)

diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.cpp b/native_client/ctcdecode/ctc_beam_search_decoder.cpp
index 37cb507ba4..165f8154fd 100644
--- a/native_client/ctcdecode/ctc_beam_search_decoder.cpp
+++ b/native_client/ctcdecode/ctc_beam_search_decoder.cpp
@@ -196,3 +196,70 @@ std::vector<Output> decoder_decode(DecoderState *state,
 
   return get_beam_search_result(state->prefixes, beam_size);
 }
+
+
+std::vector<Output> ctc_beam_search_decoder(
+    const double *probs,
+    int time_dim,
+    int class_dim,
+    const Alphabet &alphabet,
+    size_t beam_size,
+    double cutoff_prob,
+    size_t cutoff_top_n,
+    Scorer *ext_scorer) {
+
+  // One-shot wrapper: builds a decoder state, runs decoder_next() once over
+  // probs, and returns what decoder_decode() produces.
+  // NOTE(review): the state obtained from decoder_init() is never released
+  // here -- looks like a leak on every call; confirm who owns it.
+  DecoderState *state = decoder_init(alphabet, class_dim, ext_scorer);
+
+  decoder_next(probs, alphabet, state, time_dim, class_dim, cutoff_prob, cutoff_top_n, beam_size, ext_scorer);
+
+  return decoder_decode(state, alphabet, beam_size, ext_scorer);
+}
+
+std::vector<std::vector<Output>>
+ctc_beam_search_decoder_batch(
+    const double *probs,
+    int batch_size,
+    int time_dim,
+    int class_dim,
+    const int* seq_lengths,
+    int seq_lengths_size,
+    const Alphabet &alphabet,
+    size_t beam_size,
+    size_t num_processes,
+    double cutoff_prob,
+    size_t cutoff_top_n,
+    Scorer *ext_scorer) {
+  VALID_CHECK_GT(num_processes, 0, "num_processes must be positive!");
+  VALID_CHECK_EQ(batch_size, seq_lengths_size, "must have one sequence length per batch element");
+  // thread pool
+  ThreadPool pool(num_processes);
+
+  // enqueue the tasks of decoding
+  std::vector<std::future<std::vector<Output>>> res;
+  // Signed index: batch_size is an int, so comparing it against a size_t
+  // would turn a negative batch_size into a huge unsigned bound.
+  for (int i = 0; i < batch_size; ++i) {
+    // Element i starts i * time_dim * class_dim values into probs; the
+    // offset is still computed in size_t so it cannot overflow an int.
+    res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
+                                  &probs[static_cast<size_t>(i)*time_dim*class_dim],
+                                  seq_lengths[i],
+                                  class_dim,
+                                  alphabet,
+                                  beam_size,
+                                  cutoff_prob,
+                                  cutoff_top_n,
+                                  ext_scorer));
+  }
+
+  // get decoding results, in the order the tasks were enqueued
+  std::vector<std::vector<Output>> batch_results;
+  for (auto &pending : res) {
+    batch_results.emplace_back(pending.get());
+  }
+  return batch_results;
+}
diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.h b/native_client/ctcdecode/ctc_beam_search_decoder.h
index 7e83720c16..81f1b61372 100644
--- a/native_client/ctcdecode/ctc_beam_search_decoder.h
+++ b/native_client/ctcdecode/ctc_beam_search_decoder.h
@@ -68,4 +68,69 @@ std::vector<Output> decoder_decode(DecoderState *state,
                                    size_t beam_size,
                                    Scorer* ext_scorer);
 
+/* CTC Beam Search Decoder
+ * Parameters:
+ *     probs: Flattened 2-D array where each row holds the probabilities
+ *            over alphabet of one time step.
+ *     time_dim: Number of timesteps.
+ *     class_dim: Alphabet length (plus 1 for space character).
+ *     alphabet: The alphabet.
+ *     beam_size: The width of beam search.
+ *     cutoff_prob: Cutoff probability for pruning.
+ *     cutoff_top_n: Cutoff number for pruning.
+ *     ext_scorer: External scorer to evaluate a prefix, which consists of
+ *                 n-gram language model scoring and word insertion term.
+ *                 Pass null to decode the input sample without scorer.
+ * Return:
+ *     A vector where each element is a pair of score and decoding result,
+ *     in descending order.
+*/
+
+std::vector<Output> ctc_beam_search_decoder(
+    const double* probs,
+    int time_dim,
+    int class_dim,
+    const Alphabet &alphabet,
+    size_t beam_size,
+    double cutoff_prob,
+    size_t cutoff_top_n,
+    Scorer *ext_scorer);
+
+/* CTC Beam Search Decoder for batch data
+ * Parameters:
+ *     probs: Flattened 3-D array; batch element i starts at
+ *            probs[i * time_dim * class_dim] and is decoded by
+ *            ctc_beam_search_decoder().
+ *     batch_size: Number of batch elements in probs.
+ *     time_dim: Stride, in timesteps, between consecutive batch elements.
+ *     class_dim: Alphabet length (plus 1 for space character).
+ *     seq_lengths: Number of timesteps to decode for each batch element.
+ *     seq_lengths_size: Length of seq_lengths; must equal batch_size.
+ *     alphabet: The alphabet.
+ *     beam_size: The width of beam search.
+ *     num_processes: Number of threads for beam search; must be positive.
+ *     cutoff_prob: Cutoff probability for pruning.
+ *     cutoff_top_n: Cutoff number for pruning.
+ *     ext_scorer: External scorer to evaluate a prefix, which consists of
+ *                 n-gram language model scoring and word insertion term.
+ *                 Pass null to decode the input sample without scorer.
+ * Return:
+ *     A 2-D vector where each element is a vector of beam search decoding
+ *     result for one audio sample.
+*/
+std::vector<std::vector<Output>>
+ctc_beam_search_decoder_batch(
+    const double* probs,
+    int batch_size,
+    int time_dim,
+    int class_dim,
+    const int* seq_lengths,
+    int seq_lengths_size,
+    const Alphabet &alphabet,
+    size_t beam_size,
+    size_t num_processes,
+    double cutoff_prob,
+    size_t cutoff_top_n,
+    Scorer *ext_scorer);
+
 #endif  // CTC_BEAM_SEARCH_DECODER_H_