Support for DeepSpeech 0.7.x (WIP) #32

Merged
merged 7 commits into from
Jul 3, 2020
95 changes: 43 additions & 52 deletions in align/align.py
@@ -14,6 +14,8 @@
from text import Alphabet, TextCleaner, levenshtein, similarity
from utils import enweight, log_progress
from audio import DEFAULT_RATE, read_frames_from_file, vad_split
from generate_lm import convert_and_filter_topk, build_lm
from generate_package import create_bundle

BEAM_WIDTH = 500
LM_ALPHA = 1
@@ -46,7 +48,7 @@ def read_script(script_path):
dashes_to_ws=not args.text_keep_dashes,
normalize_space=not args.text_keep_ws,
to_lower=not args.text_keep_casing)
with open(script_path, 'r') as script_file:
with open(script_path, 'r', encoding='utf-8') as script_file:
content = script_file.read()
if script_path.endswith('.script'):
for phrase in json.loads(content):
@@ -61,10 +63,10 @@ def read_script(script_path):

model = None

def init_stt(output_graph_path, lm_path, trie_path):
def init_stt(output_graph_path, scorer_path):
global model
model = deepspeech.Model(output_graph_path, BEAM_WIDTH)
model.enableDecoderWithLM(lm_path, trie_path, LM_ALPHA, LM_BETA)
model = deepspeech.Model(output_graph_path)
model.enableExternalScorer(scorer_path)
logging.debug('Process {}: Loaded models'.format(os.getpid()))
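In 0.7.x the language model and trie are merged into a single external scorer package, so per-process initialization reduces to constructing the model and attaching the scorer, as `init_stt` now does. A minimal end-to-end sketch of that API, with placeholder paths and tuning values:

```python
import wave

import deepspeech
import numpy as np

# Placeholder paths; any DeepSpeech 0.7.x acoustic model (.pbmm) and scorer package work here.
model = deepspeech.Model('models/en/output_graph.pbmm')
model.enableExternalScorer('models/en/kenlm.scorer')

# Optional tuning: the 0.7.x constructor no longer takes a beam width, and the
# LM weights go through setScorerAlphaBeta (the values below are placeholders).
model.setBeamWidth(500)
model.setScorerAlphaBeta(1.0, 1.85)

# Feed 16-bit PCM at the model's sample rate and decode.
with wave.open('sample.wav', 'rb') as wav_file:
    audio = np.frombuffer(wav_file.readframes(wav_file.getnframes()), dtype=np.int16)
print(model.stt(audio))
```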


@@ -89,7 +91,7 @@ def align(triple):
gap_score=args.align_gap_score)

logging.debug("Loading transcription log from %s..." % tlog)
with open(tlog, 'r') as transcription_log_file:
with open(tlog, 'r', encoding='utf-8') as transcription_log_file:
fragments = json.load(transcription_log_file)
end_fragments = (args.start + args.num_samples) if args.num_samples else len(fragments)
fragments = fragments[args.start:end_fragments]
@@ -349,8 +351,8 @@ def apply_number(number_key, index, fragment, show, get_value):
'trim',
str(time_start / 1000.0),
'=' + str(time_end / 1000.0)])
with open(aligned, 'w') as result_file:
result_file.write(json.dumps(result_fragments, indent=4 if args.output_pretty else None))
with open(aligned, 'w', encoding='utf-8') as result_file:
result_file.write(json.dumps(result_fragments, indent=4 if args.output_pretty else None, ensure_ascii=False))
return aligned, len(result_fragments), len(fragments) - len(result_fragments), reasons
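The switch to `ensure_ascii=False` plus an explicit `encoding='utf-8'` keeps non-Latin transcripts human-readable in the aligned JSON; by default `json.dumps` escapes every character outside ASCII. A quick illustration:

```python
import json

fragment = {'transcript': 'straße único 日本語'}

print(json.dumps(fragment))                      # {"transcript": "stra\u00dfe \u00fanico ..."}
print(json.dumps(fragment, ensure_ascii=False))  # {"transcript": "straße único 日本語"}

# An explicit encoding avoids depending on the platform default when writing.
with open('aligned.json', 'w', encoding='utf-8') as result_file:
    result_file.write(json.dumps([fragment], indent=4, ensure_ascii=False))
```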


@@ -401,7 +403,7 @@ def enqueue_or_fail(audio, tlog, script, aligned, prefix=''):
fail('Unable to load catalog file "{}"'.format(args.catalog))
catalog = path.abspath(args.catalog)
catalog_dir = path.dirname(catalog)
with open(catalog, 'r') as catalog_file:
with open(catalog, 'r', encoding='utf-8') as catalog_file:
catalog_entries = json.load(catalog_file)
for entry in progress(catalog_entries, desc='Reading catalog'):
enqueue_or_fail(resolve(catalog_dir, entry['audio']),
@@ -418,11 +420,11 @@ def enqueue_or_fail(audio, tlog, script, aligned, prefix=''):
output_graph_path = None
for audio_path, tlog_path, script_path, aligned_path in to_prepare:
if not exists(tlog_path):
generated_scorer = False
if output_graph_path is None:
logging.debug('Looking for model files in "{}"...'.format(model_dir))
output_graph_path = glob(model_dir + "/output_graph.pb")[0]
lang_lm_path = glob(model_dir + "/lm.binary")[0]
lang_trie_path = glob(model_dir + "/trie")[0]
output_graph_path = glob(model_dir + "/*.pbmm")[0]
lang_scorer_path = glob(model_dir + "/*.scorer")[0]
kenlm_path = 'dependencies/kenlm/build/bin'
if not path.exists(kenlm_path):
kenlm_path = None
@@ -435,45 +437,30 @@ def enqueue_or_fail(audio, tlog, script, aligned, prefix=''):
logging.error('Cleaned transcript is empty for {}'.format(path.basename(script_path)))
continue
clean_text_path = script_path + '.clean'
with open(clean_text_path, 'w') as clean_text_file:
with open(clean_text_path, 'w', encoding='utf-8') as clean_text_file:
clean_text_file.write(tc.clean_text)

arpa_path = script_path + '.arpa'
if not path.exists(arpa_path):
subprocess.check_call([
kenlm_path + '/lmplz',
'--discount_fallback',
'--text',
clean_text_path,
'--arpa',
arpa_path,
'--o',
'5'
])

lm_path = script_path + '.lm'
if not path.exists(lm_path):
subprocess.check_call([
kenlm_path + '/build_binary',
'-s',
arpa_path,
lm_path
])

trie_path = script_path + '.trie'
if not path.exists(trie_path):
subprocess.check_call([
deepspeech_path + '/generate_trie',
alphabet_path,
lm_path,
trie_path
])
scorer_path = script_path + '.scorer'
if not path.exists(scorer_path):
# Generate LM
data_lower, vocab_str = convert_and_filter_topk(scorer_path, clean_text_path, 500000)
build_lm(scorer_path, kenlm_path, 5, '85%', '0|0|1', True, 255, 8, 'trie', data_lower, vocab_str)
os.remove(scorer_path + '.' + 'lower.txt.gz')
os.remove(scorer_path + '.' + 'lm.arpa')
os.remove(scorer_path + '.' + 'lm_filtered.arpa')
os.remove(clean_text_path)

# Generate scorer
create_bundle(alphabet_path, scorer_path + '.' + 'lm.binary', scorer_path + '.' + 'vocab-500000.txt', scorer_path, False, 0.931289039105002, 1.1834137581510284)
os.remove(scorer_path + '.' + 'lm.binary')
os.remove(scorer_path + '.' + 'vocab-500000.txt')

generated_scorer = True
else:
lm_path = lang_lm_path
trie_path = lang_trie_path
scorer_path = lang_scorer_path

logging.debug('Loading acoustic model from "{}", alphabet from "{}", trie from "{}" and language model from "{}"...'
.format(output_graph_path, alphabet_path, trie_path, lm_path))
logging.debug('Loading acoustic model from "{}", alphabet from "{}" and scorer from "{}"...'
.format(output_graph_path, alphabet_path, scorer_path))

# Run VAD on the input file
logging.debug('Transcribing VAD segments...')
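Taken together, this branch replaces the old lmplz / build_binary / generate_trie sequence with the scorer pipeline: a top-k filtered KenLM model built by `generate_lm.py` is bundled with the alphabet into one `.scorer` file. A condensed sketch of that flow; the numeric arguments are the ones passed above, and the `create_bundle` signature is inferred from that call rather than from a documented API:

```python
import os

from generate_lm import convert_and_filter_topk, build_lm
from generate_package import create_bundle  # signature assumed from the call above


def build_document_scorer(clean_text_path, scorer_path, kenlm_bins, alphabet_path):
    # 1. Lowercase the cleaned transcript and keep its 500k most frequent words.
    data_lower, vocab_str = convert_and_filter_topk(scorer_path, clean_text_path, 500000)

    # 2. Build a pruned 5-gram ARPA model, filter it to the vocabulary and
    #    quantize it into <scorer_path>.lm.binary.
    build_lm(scorer_path, kenlm_bins, 5, '85%', '0|0|1', True, 255, 8, 'trie',
             data_lower, vocab_str)

    # 3. Bundle alphabet + lm.binary + vocabulary into a single .scorer package.
    create_bundle(alphabet_path, scorer_path + '.lm.binary',
                  scorer_path + '.vocab-500000.txt', scorer_path,
                  False, 0.931289039105002, 1.1834137581510284)

    # 4. Only the .scorer file is kept; the intermediates are removed.
    for suffix in ('.lower.txt.gz', '.lm.arpa', '.lm_filtered.arpa',
                   '.lm.binary', '.vocab-500000.txt'):
        os.remove(scorer_path + suffix)
    return scorer_path
```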
@@ -499,7 +486,7 @@ def pre_filter():
samples = list(progress(pre_filter(), desc='VAD splitting'))

pool = multiprocessing.Pool(initializer=init_stt,
initargs=(output_graph_path, lm_path, trie_path),
initargs=(output_graph_path, scorer_path),
processes=args.stt_workers)
transcripts = list(progress(pool.imap(stt, samples), desc='Transcribing', total=len(samples)))

@@ -515,8 +502,12 @@ def pre_filter():
logging.debug('Excluded {} empty transcripts'.format(len(transcripts) - len(fragments)))

logging.debug('Writing transcription log to file "{}"...'.format(tlog_path))
with open(tlog_path, 'w') as tlog_file:
tlog_file.write(json.dumps(fragments, indent=4 if args.output_pretty else None))
with open(tlog_path, 'w', encoding='utf-8') as tlog_file:
tlog_file.write(json.dumps(fragments, indent=4 if args.output_pretty else None, ensure_ascii=False))

# Remove scorer if generated
if generated_scorer:
os.remove(scorer_path)
if not path.isfile(tlog_path):
fail('Problem loading transcript from "{}"'.format(tlog_path))
to_align.append((tlog_path, script_path, aligned_path))
@@ -595,13 +586,13 @@ def parse_args():
stt_group.add_argument('--stt-model-rate', type=int, default=DEFAULT_RATE,
help='Supported sample rate of the acoustic model')
stt_group.add_argument('--stt-model-dir', required=False,
help='Path to a directory with output_graph, lm, trie and (optional) alphabet file ' +
'(default: "data/en"')
help='Path to a directory with output_graph, scorer and (optional) alphabet file ' +
'(default: "models/en"')
stt_group.add_argument('--stt-no-own-lm', action="store_true",
help='Deactivates creation of individual language models per document.' +
'Uses the one from model dir instead.')
stt_group.add_argument('--stt-workers', type=int, required=False, default=1,
help='Number of parallel STT workers - should 1 for GPU based DeepSpeech')
help='Number of parallel STT workers - should be 1 for GPU based DeepSpeech')
stt_group.add_argument('--stt-min-duration', type=int, required=False, default=100,
help='Minimum speech fragment duration in milliseconds to translate (default: 100)')
stt_group.add_argument('--stt-max-duration', type=int, required=False,
8 changes: 7 additions & 1 deletion in align/audio.py
@@ -155,6 +155,7 @@ def ensure_wav_with_format(src_audio_path, audio_format=DEFAULT_FORMAT, tmp_dir=
return src_audio_path, False
fd, tmp_file_path = tempfile.mkstemp(suffix='.wav', dir=tmp_dir)
os.close(fd)
fd = None
if convert_audio(src_audio_path, tmp_file_path, file_type='wav', audio_format=audio_format):
return tmp_file_path, True
os.remove(tmp_file_path)
@@ -185,7 +186,9 @@ def __enter__(self):
return self.audio_path
return self.open_file
self.open_file.close()
_, self.tmp_file_path = tempfile.mkstemp(suffix='.wav')
test, self.tmp_file_path = tempfile.mkstemp(suffix='.wav')
os.close(test)
test = None
if not convert_audio(self.audio_path, self.tmp_file_path, file_type='wav', audio_format=self.audio_format):
raise RuntimeError('Unable to convert "{}" to required format'.format(self.audio_path))
if self.as_path:
@@ -194,10 +197,12 @@ def __enter__(self):
return self.open_file

def __exit__(self, *args):
self.open_file.close()
if not self.as_path:
self.open_file.close()
if self.tmp_file_path is not None:
os.remove(self.tmp_file_path)
self.open_file = None


def read_frames(wav_file, frame_duration_ms=30, yield_remainder=False):
@@ -337,6 +342,7 @@ def read_wav(wav_file):
with wave.open(wav_file, 'rb') as wav_file_reader:
audio_format = read_audio_format_from_wav_file(wav_file_reader)
pcm_data = wav_file_reader.readframes(wav_file_reader.getnframes())
os.close(wav_file)
return audio_format, pcm_data
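The added `os.close(...)` calls in this file matter because `tempfile.mkstemp` returns an already-open OS-level file descriptor together with the path; removing the file later does not release the descriptor, so every conversion would otherwise leak one. The pattern as a small standalone sketch:

```python
import os
import tempfile

# mkstemp returns (fd, path) with the descriptor already open; close it right
# away if the file will only be written by an external tool afterwards.
fd, tmp_wav_path = tempfile.mkstemp(suffix='.wav')
os.close(fd)

try:
    pass  # hand tmp_wav_path to the converter (e.g. convert_audio) here
finally:
    os.remove(tmp_wav_path)
```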


4 changes: 2 additions & 2 deletions in align/catalog_tool.py
@@ -25,7 +25,7 @@ def build_catalog():
print('Loading catalog "{}"'.format(str(catalog_original_path)))
if not catalog_path.is_file():
fail('Unable to find catalog file "{}"'.format(str(catalog_path)))
with open(catalog_path, 'r') as catalog_file:
with open(catalog_path, 'r', encoding='utf-8') as catalog_file:
catalog_items = json.load(catalog_file)
base_path = catalog_path.parent.absolute()
for item in catalog_items:
@@ -62,7 +62,7 @@ def build_catalog():
item[entry] = str(Path(item[entry]).relative_to(base_path))
if CLI_ARGS.order_by is not None:
items.sort(key=lambda i: i[CLI_ARGS.order_by] if CLI_ARGS.order_by in i else '')
with open(catalog_path, 'w') as catalog_file:
with open(catalog_path, 'w', encoding='utf-8') as catalog_file:
json.dump(items, catalog_file, indent=2)


6 changes: 3 additions & 3 deletions in align/export.py
@@ -106,7 +106,7 @@ def load_catalog():
elif CLI_ARGS.catalog:
catalog = check_path(CLI_ARGS.catalog)
catalog_dir = path.dirname(catalog)
with open(catalog, 'r') as catalog_file:
with open(catalog, 'r', encoding='utf-8') as catalog_file:
catalog_file_entries = json.load(catalog_file)
for entry in progress(catalog_file_entries, desc='Reading catalog'):
audio = make_absolute(catalog_dir, entry['audio'])
@@ -146,7 +146,7 @@ def get_meta_list(ae, mf):
reasons = Counter()
for catalog_index, catalog_entry in enumerate(progress(catalog_entries, desc='Loading alignments')):
audio_path, aligned_path = catalog_entry
with open(aligned_path, 'r') as aligned_file:
with open(aligned_path, 'r', encoding='utf-8') as aligned_file:
aligned = json.load(aligned_file)
for alignment_index, alignment in enumerate(aligned):
quality = eval(CLI_ARGS.criteria, {'math': math}, alignment)
@@ -443,7 +443,7 @@ def parse_args():
def load_sample(entry):
catalog_index, catalog_entry = entry
audio_path, aligned_path = catalog_entry
with open(aligned_path, 'r') as aligned_file:
with open(aligned_path, 'r', encoding='utf-8') as aligned_file:
aligned = json.load(aligned_file)
tries = 2
while tries > 0:
125 changes: 125 additions & 0 deletions in align/generate_lm.py
@@ -0,0 +1,125 @@
import gzip
import io
import os
import subprocess
from collections import Counter

def convert_and_filter_topk(output_dir, input_txt, top_k):
""" Convert to lowercase, count word occurrences and save top-k words to a file """

counter = Counter()
data_lower = output_dir + "." + "lower.txt.gz"

print("\nConverting to lowercase and counting word occurrences ...")
with io.TextIOWrapper(
io.BufferedWriter(gzip.open(data_lower, "w+")), encoding="utf-8"
) as file_out:

# Open the input file either from input.txt or input.txt.gz
_, file_extension = os.path.splitext(input_txt)
if file_extension == ".gz":
file_in = io.TextIOWrapper(
io.BufferedReader(gzip.open(input_txt)), encoding="utf-8"
)
else:
file_in = open(input_txt, encoding="utf-8")

for line in file_in:
line_lower = line.lower()
counter.update(line_lower.split())
file_out.write(line_lower)

file_in.close()

# Save top-k words
print("\nSaving top {} words ...".format(top_k))
top_counter = counter.most_common(top_k)
vocab_str = "\n".join(word for word, count in top_counter)
vocab_path = "vocab-{}.txt".format(top_k)
vocab_path = output_dir + "." + vocab_path
with open(vocab_path, "w+", encoding="utf-8") as file:
file.write(vocab_str)

print("\nCalculating word statistics ...")
total_words = sum(counter.values())
print(" Your text file has {} words in total".format(total_words))
print(" It has {} unique words".format(len(counter)))
top_words_sum = sum(count for word, count in top_counter)
word_fraction = (top_words_sum / total_words) * 100
print(
" Your top-{} words are {:.4f} percent of all words".format(
top_k, word_fraction
)
)
print(' Your most common word "{}" occurred {} times'.format(*top_counter[0]))
last_word, last_count = top_counter[-1]
print(
' The least common word in your top-k is "{}" with {} times'.format(
last_word, last_count
)
)
for i, (w, c) in enumerate(reversed(top_counter)):
if c > last_count:
print(
' The first word with {} occurrences is "{}" at place {}'.format(
c, w, len(top_counter) - 1 - i
)
)
break

return data_lower, vocab_str


def build_lm(output_dir, kenlm_bins, arpa_order, max_arpa_memory, arpa_prune, discount_fallback, binary_a_bits, binary_q_bits, binary_type, data_lower, vocab_str):
print("\nCreating ARPA file ...")
lm_path = output_dir + "." + "lm.arpa"
subargs = [
os.path.join(kenlm_bins, "lmplz"),
"--order",
str(arpa_order),
"--temp_prefix",
output_dir,
"--memory",
max_arpa_memory,
"--text",
data_lower,
"--arpa",
lm_path,
"--prune",
*arpa_prune.split("|"),
]
if discount_fallback:
subargs += ["--discount_fallback"]
subprocess.check_call(subargs)

# Filter LM using vocabulary of top-k words
print("\nFiltering ARPA file using vocabulary of top-k words ...")
filtered_path = output_dir + "." + "lm_filtered.arpa"
subprocess.run(
[
os.path.join(kenlm_bins, "filter"),
"single",
"model:{}".format(lm_path),
filtered_path,
],
input=vocab_str.encode("utf-8"),
check=True,
)

# Quantize and produce trie binary.
print("\nBuilding lm.binary ...")
binary_path = output_dir + "." + "lm.binary"
subprocess.check_call(
[
os.path.join(kenlm_bins, "build_binary"),
"-s",
"-a",
str(binary_a_bits),
"-q",
str(binary_q_bits),
"-v",
binary_type,
filtered_path,
binary_path,
]
)
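A short usage sketch of `convert_and_filter_topk` with hypothetical paths, mainly to show how the output prefix is interpreted and what the caller gets back; despite its name, `output_dir` is used as a filename prefix, which is how `align.py` passes the target scorer path into it:

```python
from generate_lm import convert_and_filter_topk

# Hypothetical inputs: a cleaned transcript and the scorer path used as prefix.
data_lower, vocab_str = convert_and_filter_topk('work/chapter01.scorer',
                                                'work/chapter01.txt.clean', 1000)

print(data_lower)                  # work/chapter01.scorer.lower.txt.gz
print(vocab_str.splitlines()[:5])  # five most frequent words, most common first
# The vocabulary itself is written to work/chapter01.scorer.vocab-1000.txt.
```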