diff --git a/native_client/alphabet.cc b/native_client/alphabet.cc index 873b4881be..1f0a8dbea2 100644 --- a/native_client/alphabet.cc +++ b/native_client/alphabet.cc @@ -3,6 +3,41 @@ #include +// std::getline, but handle newline conventions from multiple platforms instead +// of just the platform this code was built for +std::istream& +getline_crossplatform(std::istream& is, std::string& t) +{ + t.clear(); + + // The characters in the stream are read one-by-one using a std::streambuf. + // That is faster than reading them one-by-one using the std::istream. + // Code that uses streambuf this way must be guarded by a sentry object. + // The sentry object performs various tasks, + // such as thread synchronization and updating the stream state. + std::istream::sentry se(is, true); + std::streambuf* sb = is.rdbuf(); + + while (true) { + int c = sb->sbumpc(); + switch (c) { + case '\n': + return is; + case '\r': + if(sb->sgetc() == '\n') + sb->sbumpc(); + return is; + case std::streambuf::traits_type::eof(): + // Also handle the case when the last line has no line ending + if(t.empty()) + is.setstate(std::ios::eofbit); + return is; + default: + t += (char)c; + } + } +} + int Alphabet::init(const char *config_file) { @@ -12,7 +47,7 @@ Alphabet::init(const char *config_file) } unsigned int label = 0; space_label_ = -2; - for (std::string line; std::getline(in, line);) { + for (std::string line; getline_crossplatform(in, line);) { if (line.size() == 2 && line[0] == '\\' && line[1] == '#') { line = '#'; } else if (line[0] == '#') { diff --git a/native_client/ctcdecode/__init__.py b/native_client/ctcdecode/__init__.py index ac603aa9ba..ee5645d408 100644 --- a/native_client/ctcdecode/__init__.py +++ b/native_client/ctcdecode/__init__.py @@ -45,6 +45,11 @@ def __init__(self, config_path): if err != 0: raise ValueError('Alphabet initialization failed with error code 0x{:X}'.format(err)) + def Encode(self, input): + """Convert SWIG's UnsignedIntVec to a Python list""" + res = super(Alphabet, self).Encode(input) + return [el for el in res] + def ctc_beam_search_decoder(probs_seq, alphabet,