Skip to content

Commit

Permalink
Merge pull request #3279 from godefv/decoder_timesteps
Browse files Browse the repository at this point in the history
The CTC decoder timesteps now corresponds to the timesteps of the most probable CTC path, instead of the earliest timesteps of all possible paths.
  • Loading branch information
reuben authored Sep 17, 2020
2 parents 014479e + 188501a commit cc62aa2
Show file tree
Hide file tree
Showing 10 changed files with 277 additions and 31 deletions.
1 change: 1 addition & 0 deletions native_client/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ cc_library(
includes = [
".",
"ctcdecode/third_party/ThreadPool",
"ctcdecode/third_party/object_pool",
] + OPENFST_INCLUDES_PLATFORM,
deps = [":kenlm"],
linkopts = [
Expand Down
3 changes: 2 additions & 1 deletion native_client/ctcdecode/build_archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
'..',
'../kenlm',
OPENFST_DIR + '/src/include',
'third_party/ThreadPool'
'third_party/ThreadPool',
'third_party/object_pool'
]

KENLM_FILES = (glob.glob('../kenlm/util/*.cc')
Expand Down
40 changes: 36 additions & 4 deletions native_client/ctcdecode/ctc_beam_search_decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ DecoderState::init(const Alphabet& alphabet,
PathTrie *root = new PathTrie;
root->score = root->log_prob_b_prev = 0.0;
prefix_root_.reset(root);
prefix_root_->timesteps = &timestep_tree_root_;
prefixes_.push_back(root);

if (ext_scorer && (bool)(ext_scorer_->dictionary)) {
Expand Down Expand Up @@ -96,24 +97,46 @@ DecoderState::next(const double *probs,
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break;
}
if (prefix->score == -NUM_FLT_INF) {
continue;
}
assert(prefix->timesteps != nullptr);

// blank
if (c == blank_id_) {
// compute probability of current path
float log_p = log_prob_c + prefix->score;

// combine current path with previous ones with the same prefix
// the blank label comes last, so we can compare log_prob_nb_cur with log_p
if (prefix->log_prob_nb_cur < log_p) {
// keep current timesteps
prefix->previous_timesteps = nullptr;
}
prefix->log_prob_b_cur =
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
log_sum_exp(prefix->log_prob_b_cur, log_p);
continue;
}

// repeated character
if (c == prefix->character) {
// compute probability of current path
float log_p = log_prob_c + prefix->log_prob_nb_prev;

// combine current path with previous ones with the same prefix
if (prefix->log_prob_nb_cur < log_p) {
// keep current timesteps
prefix->previous_timesteps = nullptr;
}
prefix->log_prob_nb_cur = log_sum_exp(
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
prefix->log_prob_nb_cur, log_p);
}

// get new prefix
auto prefix_new = prefix->get_path_trie(c, abs_time_step_, log_prob_c);
auto prefix_new = prefix->get_path_trie(c, log_prob_c);

if (prefix_new != nullptr) {
// compute probability of current path
float log_p = -NUM_FLT_INF;

if (c == prefix->character &&
Expand Down Expand Up @@ -144,6 +167,13 @@ DecoderState::next(const double *probs,
}
}

// combine current path with previous ones with the same prefix
if (prefix_new->log_prob_nb_cur < log_p) {
// record data needed to update timesteps
// the actual update will be done if nothing better is found
prefix_new->previous_timesteps = prefix->timesteps;
prefix_new->new_timestep = abs_time_step_;
}
prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
}
Expand Down Expand Up @@ -207,7 +237,9 @@ DecoderState::decode(size_t num_results) const

for (size_t i = 0; i < num_returned; ++i) {
Output output;
prefixes_copy[i]->get_path_vec(output.tokens, output.timesteps);
prefixes_copy[i]->get_path_vec(output.tokens);
output.timesteps = get_history(prefixes_copy[i]->timesteps, &timestep_tree_root_);
assert(output.tokens.size() == output.timesteps.size());
output.confidence = scores[prefixes_copy[i]];
outputs.push_back(output);
}
Expand Down
1 change: 1 addition & 0 deletions native_client/ctcdecode/ctc_beam_search_decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class DecoderState {
std::shared_ptr<Scorer> ext_scorer_;
std::vector<PathTrie*> prefixes_;
std::unique_ptr<PathTrie> prefix_root_;
TimestepTreeNode timestep_tree_root_{nullptr, 0};

public:
DecoderState() = default;
Expand Down
45 changes: 27 additions & 18 deletions native_client/ctcdecode/path_trie.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ PathTrie::PathTrie() {

ROOT_ = -1;
character = ROOT_;
timestep = 0;
exists_ = true;
parent = nullptr;

Expand All @@ -35,7 +34,7 @@ PathTrie::~PathTrie() {
}
}

PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timestep, float cur_log_prob_c, bool reset) {
PathTrie* PathTrie::get_path_trie(unsigned int new_char, float cur_log_prob_c, bool reset) {
auto child = children_.begin();
for (; child != children_.end(); ++child) {
if (child->first == new_char) {
Expand Down Expand Up @@ -67,7 +66,6 @@ PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timest
} else {
PathTrie* new_path = new PathTrie;
new_path->character = new_char;
new_path->timestep = new_timestep;
new_path->parent = this;
new_path->dictionary_ = dictionary_;
new_path->has_dictionary_ = true;
Expand All @@ -93,7 +91,6 @@ PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timest
} else {
PathTrie* new_path = new PathTrie;
new_path->character = new_char;
new_path->timestep = new_timestep;
new_path->parent = this;
new_path->log_prob_c = cur_log_prob_c;
children_.push_back(std::make_pair(new_char, new_path));
Expand All @@ -102,20 +99,18 @@ PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timest
}
}

void PathTrie::get_path_vec(std::vector<unsigned int>& output, std::vector<unsigned int>& timesteps) {
void PathTrie::get_path_vec(std::vector<unsigned int>& output) {
// Recursive call: recurse back until stop condition, then append data in
// correct order as we walk back down the stack in the lines below.
if (parent != nullptr) {
parent->get_path_vec(output, timesteps);
parent->get_path_vec(output);
}
if (character != ROOT_) {
output.push_back(character);
timesteps.push_back(timestep);
}
}

PathTrie* PathTrie::get_prev_grapheme(std::vector<unsigned int>& output,
std::vector<unsigned int>& timesteps,
const Alphabet& alphabet)
{
PathTrie* stop = this;
Expand All @@ -125,10 +120,9 @@ PathTrie* PathTrie::get_prev_grapheme(std::vector<unsigned int>& output,
// Recursive call: recurse back until stop condition, then append data in
// correct order as we walk back down the stack in the lines below.
if (!byte_is_codepoint_boundary(alphabet.DecodeSingle(character)[0])) {
stop = parent->get_prev_grapheme(output, timesteps, alphabet);
stop = parent->get_prev_grapheme(output, alphabet);
}
output.push_back(character);
timesteps.push_back(timestep);
return stop;
}

Expand All @@ -147,7 +141,6 @@ int PathTrie::distance_to_codepoint_boundary(unsigned char *first_byte,
}

PathTrie* PathTrie::get_prev_word(std::vector<unsigned int>& output,
std::vector<unsigned int>& timesteps,
const Alphabet& alphabet)
{
PathTrie* stop = this;
Expand All @@ -157,14 +150,18 @@ PathTrie* PathTrie::get_prev_word(std::vector<unsigned int>& output,
// Recursive call: recurse back until stop condition, then append data in
// correct order as we walk back down the stack in the lines below.
if (parent != nullptr) {
stop = parent->get_prev_word(output, timesteps, alphabet);
stop = parent->get_prev_word(output, alphabet);
}
output.push_back(character);
timesteps.push_back(timestep);
return stop;
}

void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
// previous_timesteps might point to ancestors' timesteps
// therefore, children must be uptaded first
for (auto child : children_) {
child.second->iterate_to_vec(output);
}
if (exists_) {
log_prob_b_prev = log_prob_b_cur;
log_prob_nb_prev = log_prob_nb_cur;
Expand All @@ -173,11 +170,23 @@ void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
log_prob_nb_cur = -NUM_FLT_INF;

score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);

if (previous_timesteps != nullptr) {
timesteps = nullptr;
for (auto const& child : previous_timesteps->children) {
if (child->data == new_timestep) {
timesteps = child.get();
break;
}
}
if (timesteps == nullptr) {
timesteps = add_child(previous_timesteps, new_timestep);
}
}
previous_timesteps = nullptr;

output.push_back(this);
}
for (auto child : children_) {
child.second->iterate_to_vec(output);
}
}

void PathTrie::remove() {
Expand Down Expand Up @@ -229,8 +238,8 @@ void PathTrie::print(const Alphabet& a) {
}
}
printf("\ntimesteps:\t ");
for (PathTrie* el : chain) {
printf("%d ", el->timestep);
for (unsigned int timestep : get_history(timesteps)) {
printf("%d ", timestep);
}
printf("\n");
printf("transcript:\t %s\n", tr.c_str());
Expand Down
65 changes: 60 additions & 5 deletions native_client/ctcdecode/path_trie.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,34 @@

#include "fst/fstlib.h"
#include "alphabet.h"
#include "object_pool.h"

/* Tree structure with parent and children information
* It is used to store the timesteps data for the PathTrie below
*/
template<class DataT>
struct TreeNode {
TreeNode<DataT>* parent;
std::vector<std::unique_ptr< TreeNode<DataT>, godefv::object_pool_deleter_t<TreeNode<DataT>> >> children;

DataT data;

TreeNode(TreeNode<DataT>* parent_, DataT const& data_): parent{parent_}, data{data_} {}
};

/* Creates a new TreeNode<NodeDataT> with given data as a child to the given node.
* Returns a pointer to the created node. This pointer remains valid as long as the child is not destroyed.
*/
template<class NodeDataT, class ChildDataT>
TreeNode<NodeDataT>* add_child(TreeNode<NodeDataT>* tree_node, ChildDataT&& data);

/* Returns the sequence of tree node's data from the given root (exclusive) to the given tree_node (inclusive).
* By default (if no root is provided), the full sequence from the root of the tree is returned.
*/
template<class DataT>
std::vector<DataT> get_history(TreeNode<DataT> const* tree_node, TreeNode<DataT> const* root = nullptr);

using TimestepTreeNode = TreeNode<unsigned int>;

/* Trie tree for prefix storing and manipulating, with a dictionary in
* finite-state transducer for spelling correction.
Expand All @@ -21,22 +49,20 @@ class PathTrie {
~PathTrie();

// get new prefix after appending new char
PathTrie* get_path_trie(unsigned int new_char, unsigned int new_timestep, float log_prob_c, bool reset = true);
PathTrie* get_path_trie(unsigned int new_char, float log_prob_c, bool reset = true);

// get the prefix data in correct time order from root to current node
void get_path_vec(std::vector<unsigned int>& output, std::vector<unsigned int>& timesteps);
void get_path_vec(std::vector<unsigned int>& output);

// get the prefix data in correct time order from beginning of last grapheme to current node
PathTrie* get_prev_grapheme(std::vector<unsigned int>& output,
std::vector<unsigned int>& timesteps,
const Alphabet& alphabet);

// get the distance from current node to the first codepoint boundary, and the byte value at the boundary
int distance_to_codepoint_boundary(unsigned char *first_byte, const Alphabet& alphabet);

// get the prefix data in correct time order from beginning of last word to current node
PathTrie* get_prev_word(std::vector<unsigned int>& output,
std::vector<unsigned int>& timesteps,
const Alphabet& alphabet);

// update log probs
Expand Down Expand Up @@ -65,7 +91,12 @@ class PathTrie {
float score;
float approx_ctc;
unsigned int character;
unsigned int timestep;
TimestepTreeNode* timesteps = nullptr;

// timestep temporary storage for each decoding step.
TimestepTreeNode* previous_timesteps = nullptr;
unsigned int new_timestep;

PathTrie* parent;

private:
Expand All @@ -81,4 +112,28 @@ class PathTrie {
std::shared_ptr<fst::SortedMatcher<FstType>> matcher_;
};

// TreeNode implementation
template<class NodeDataT, class ChildDataT>
TreeNode<NodeDataT>* add_child(TreeNode<NodeDataT>* tree_node, ChildDataT&& data) {
static thread_local godefv::object_pool_t<TreeNode<NodeDataT>> tree_node_pool;
tree_node->children.push_back(tree_node_pool.make_unique(tree_node, std::forward<ChildDataT>(data)));
return tree_node->children.back().get();
}

template<class DataT>
void get_history_helper(TreeNode<DataT> const* tree_node, TreeNode<DataT> const* root, std::vector<DataT>* output) {
if (tree_node == root) return;
assert(tree_node != nullptr);
assert(tree_node->parent != tree_node);
get_history_helper(tree_node->parent, root, output);
output->push_back(tree_node->data);
}
template<class DataT>
std::vector<DataT> get_history(TreeNode<DataT> const* tree_node, TreeNode<DataT> const* root) {
std::vector<DataT> output;
get_history_helper(tree_node, root, &output);
return output;
}


#endif // PATH_TRIE_H
5 changes: 2 additions & 3 deletions native_client/ctcdecode/scorer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,12 +293,11 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
}

std::vector<unsigned int> prefix_vec;
std::vector<unsigned int> prefix_steps;

if (is_utf8_mode_) {
new_node = current_node->get_prev_grapheme(prefix_vec, prefix_steps, alphabet_);
new_node = current_node->get_prev_grapheme(prefix_vec, alphabet_);
} else {
new_node = current_node->get_prev_word(prefix_vec, prefix_steps, alphabet_);
new_node = current_node->get_prev_word(prefix_vec, alphabet_);
}
current_node = new_node->parent;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
This code was imported from https://github.com/godefv/memory on September 17th 2020, commit 5ff1af8ee09ced04990b4863b2c02a8d07f4356a. It's licensed under "CC0 1.0 Universal" license.
Loading

0 comments on commit cc62aa2

Please sign in to comment.