Skip to content

Commit

Permalink
Fix bindings of native Alphabet and usage in training code
Browse files Browse the repository at this point in the history
  • Loading branch information
reuben committed Jun 28, 2020
1 parent 7f1696a commit 7687f71
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 13 deletions.
12 changes: 8 additions & 4 deletions native_client/alphabet.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ Alphabet::Deserialize(const char* buffer, const int buffer_size)
return 0;
}

const std::string& Alphabet::DecodeSingle(unsigned int label) const {
std::string
Alphabet::DecodeSingle(unsigned int label) const {
auto it = label_to_str_.find(label);
if (it != label_to_str_.end()) {
return it->second;
Expand All @@ -109,7 +110,8 @@ const std::string& Alphabet::DecodeSingle(unsigned int label) const {
}
}

unsigned int Alphabet::EncodeSingle(const std::string& string) const {
unsigned int
Alphabet::EncodeSingle(const std::string& string) const {
auto it = str_to_label_.find(string);
if (it != str_to_label_.end()) {
return it->second;
Expand All @@ -119,15 +121,17 @@ unsigned int Alphabet::EncodeSingle(const std::string& string) const {
}
}

std::string Alphabet::Decode(const std::vector<unsigned int>& input) const {
// Decode a sequence of labels into a single string: each label is passed to
// DecodeSingle() and the results are concatenated in input order. An empty
// input yields an empty string.
std::string
Alphabet::Decode(const std::vector<unsigned int>& input) const {
  std::string decoded;
  for (const unsigned int label : input) {
    decoded += DecodeSingle(label);
  }
  return decoded;
}

std::vector<unsigned int> Alphabet::Encode(const std::string& input) const {
std::vector<unsigned int>
Alphabet::Encode(const std::string& input) const {
std::vector<unsigned int> result;
for (auto cp : split_into_codepoints(input)) {
result.push_back(EncodeSingle(cp));
Expand Down
7 changes: 5 additions & 2 deletions native_client/alphabet.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Alphabet {
Alphabet() = default;
Alphabet(const Alphabet&) = default;
Alphabet& operator=(const Alphabet&) = default;
virtual ~Alphabet() = default;

virtual int init(const char *config_file);

Expand All @@ -37,7 +38,7 @@ class Alphabet {
}

// Decode a single label into a string.
const std::string& DecodeSingle(unsigned int label) const;
std::string DecodeSingle(unsigned int label) const;

// Encode a single character/output class into a label.
unsigned int EncodeSingle(const std::string& string) const;
Expand Down Expand Up @@ -69,7 +70,9 @@ class UTF8Alphabet : public Alphabet
}
}

int init(const char*) override {}
// Overrides Alphabet::init. The argument is ignored and the call always
// returns 0.
int init(const char* /*config_file*/) override { return 0; }
};


Expand Down
2 changes: 1 addition & 1 deletion native_client/ctcdecode/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import absolute_import, division, print_function

from . import swigwrapper # pylint: disable=import-self
from .swigwrapper import Alphabet
from .swigwrapper import UTF8Alphabet

__version__ = swigwrapper.__version__

Expand Down
3 changes: 2 additions & 1 deletion native_client/ctcdecode/build_archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@
'scorer.cpp',
'path_trie.cpp',
'decoder_utils.cpp',
'workspace_status.cc'
'workspace_status.cc',
'../alphabet.cc',
]

def build_archive(srcs=[], out_name='', build_dir='temp_build/temp_build', debug=False, num_parallel=1):
Expand Down
2 changes: 1 addition & 1 deletion native_client/ctcdecode/swigwrapper.i
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ namespace std {
%constant const char* __version__ = ds_version();
%constant const char* __git_version__ = ds_git_version();

%template(IntVector) std::vector<int>;
%template(UnsignedIntVector) std::vector<unsigned int>;
%template(OutputVector) std::vector<Output>;
%template(OutputVectorVector) std::vector<std::vector<Output>>;

Expand Down
6 changes: 2 additions & 4 deletions training/deepspeech_training/util/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,10 @@ def check_ctcdecoder_version():
sys.exit(1)
raise e

decoder_version_s = decoder_version.decode()

rv = semver.compare(ds_version_s, decoder_version_s)
rv = semver.compare(ds_version_s, decoder_version)
if rv != 0:
print("DeepSpeech version ({}) and CTC decoder version ({}) do not match. "
"Please ensure matching versions are in use.".format(ds_version_s, decoder_version_s))
"Please ensure matching versions are in use.".format(ds_version_s, decoder_version))
sys.exit(1)

return rv
Expand Down

0 comments on commit 7687f71

Please sign in to comment.