Skip to content

Commit

Permalink
Merge pull request #82 from TensorSpeech/dev/testing
Browse files Browse the repository at this point in the history
Update batch for faster testing
  • Loading branch information
nglehuy authored Dec 19, 2020
2 parents 46edde8 + 288584a commit 6d70eab
Show file tree
Hide file tree
Showing 15 changed files with 250 additions and 139 deletions.
6 changes: 3 additions & 3 deletions examples/conformer/masking/masking.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import tensorflow as tf
from tensorflow_asr.utils.utils import shape_list
from tensorflow_asr.utils.utils import shape_list, get_reduced_length


def create_padding_mask(features, input_length, time_reduction_factor):
Expand All @@ -14,10 +14,10 @@ def create_padding_mask(features, input_length, time_reduction_factor):
[tf.Tensor]: with shape [B, Tquery, Tkey]
"""
batch_size, padded_time, _, _ = shape_list(features)
reduced_padded_time = tf.math.ceil(padded_time / time_reduction_factor)
reduced_padded_time = get_reduced_length(padded_time, time_reduction_factor)

def create_mask(length):
reduced_length = tf.math.ceil(length / time_reduction_factor)
reduced_length = get_reduced_length(length, time_reduction_factor)
mask = tf.ones([reduced_length, reduced_length], dtype=tf.float32)
return tf.pad(
mask,
Expand Down
5 changes: 3 additions & 2 deletions examples/conformer/masking/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from masking import create_padding_mask
from tensorflow_asr.runners.transducer_runners import TransducerTrainer, TransducerTrainerGA
from tensorflow_asr.losses.rnnt_losses import rnnt_loss
from tensorflow_asr.utils.utils import get_reduced_length


class TrainerWithMasking(TransducerTrainer):
Expand All @@ -17,7 +18,7 @@ def _train_step(self, batch):
tape.watch(logits)
per_train_loss = rnnt_loss(
logits=logits, labels=labels, label_length=label_length,
logit_length=tf.cast(tf.math.ceil(input_length / self.model.time_reduction_factor), dtype=tf.int32),
logit_length=get_reduced_length(input_length, self.model.time_reduction_factor),
blank=self.text_featurizer.blank
)
train_loss = tf.nn.compute_average_loss(per_train_loss,
Expand All @@ -41,7 +42,7 @@ def _train_step(self, batch):
tape.watch(logits)
per_train_loss = rnnt_loss(
logits=logits, labels=labels, label_length=label_length,
logit_length=tf.cast(tf.math.ceil(input_length / self.model.time_reduction_factor), dtype=tf.int32),
logit_length=get_reduced_length(input_length, self.model.time_reduction_factor),
blank=self.text_featurizer.blank
)
train_loss = tf.nn.compute_average_loss(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

setuptools.setup(
name="TensorFlowASR",
version="0.5.0",
version="0.5.1",
author="Huy Le Nguyen",
author_email="nlhuy.cs.16@gmail.com",
description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_asr/datasets/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,5 @@ Where `prediction` and `prediction_length` are the label prepanded by blank and
**Outputs when iterating in test step**

```python
(path, signals, labels)
(path, features, input_lengths, labels)
```
32 changes: 24 additions & 8 deletions tensorflow_asr/datasets/asr_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,15 @@ class ASRTFRecordTestDataset(ASRTFRecordDataset):
def preprocess(self, path, transcript):
with tf.device("/CPU:0"):
signal = read_raw_audio(path.decode("utf-8"), self.speech_featurizer.sample_rate)

features = self.speech_featurizer.extract(signal)
features = tf.convert_to_tensor(features, tf.float32)
input_length = tf.cast(tf.shape(features)[0], tf.int32)

label = self.text_featurizer.extract(transcript.decode("utf-8"))
return path, signal, tf.convert_to_tensor(label, dtype=tf.int32)
label = tf.convert_to_tensor(label, dtype=tf.int32)

return path, features, input_length, label

@tf.function
def parse(self, record):
Expand All @@ -256,7 +263,7 @@ def parse(self, record):
return tf.numpy_function(
self.preprocess,
inp=[example["audio"], example["transcript"]],
Tout=(tf.string, tf.float32, tf.int32)
Tout=(tf.string, tf.float32, tf.int32, tf.int32)
)

def process(self, dataset, batch_size):
Expand All @@ -273,10 +280,11 @@ def process(self, dataset, batch_size):
batch_size=batch_size,
padded_shapes=(
tf.TensorShape([]),
tf.TensorShape([None]),
tf.TensorShape(self.speech_featurizer.shape),
tf.TensorShape([]),
tf.TensorShape([None]),
),
padding_values=("", 0.0, self.text_featurizer.blank),
padding_values=("", 0.0, 0, self.text_featurizer.blank),
drop_remainder=True
)

Expand Down Expand Up @@ -304,15 +312,22 @@ class ASRSliceTestDataset(ASRDataset):
def preprocess(self, path, transcript):
with tf.device("/CPU:0"):
signal = read_raw_audio(path.decode("utf-8"), self.speech_featurizer.sample_rate)

features = self.speech_featurizer.extract(signal)
features = tf.convert_to_tensor(features, tf.float32)
input_length = tf.cast(tf.shape(features)[0], tf.int32)

label = self.text_featurizer.extract(transcript.decode("utf-8"))
return path, signal, tf.convert_to_tensor(label, dtype=tf.int32)
label = tf.convert_to_tensor(label, dtype=tf.int32)

return path, features, input_length, label

@tf.function
def parse(self, record):
return tf.numpy_function(
self.preprocess,
inp=[record[0], record[1]],
Tout=[tf.string, tf.float32, tf.int32]
Tout=[tf.string, tf.float32, tf.int32, tf.int32]
)

def process(self, dataset, batch_size):
Expand All @@ -329,10 +344,11 @@ def process(self, dataset, batch_size):
batch_size=batch_size,
padded_shapes=(
tf.TensorShape([]),
tf.TensorShape([None]),
tf.TensorShape(self.speech_featurizer.shape),
tf.TensorShape([]),
tf.TensorShape([None]),
),
padding_values=("", 0.0, self.text_featurizer.blank),
padding_values=("", 0.0, 0, self.text_featurizer.blank),
drop_remainder=True
)

Expand Down
8 changes: 8 additions & 0 deletions tensorflow_asr/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,11 @@ def _build(self, *args, **kwargs):
@abc.abstractmethod
def call(self, inputs, training=False, **kwargs):
raise NotImplementedError()

@abc.abstractmethod
def recognize(self, features, input_lengths, **kwargs):
pass

@abc.abstractmethod
def recognize_beam(self, features, input_lengths, **kwargs):
pass
11 changes: 7 additions & 4 deletions tensorflow_asr/models/contextnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
""" Ref: https://github.com/iankur/ContextNet """

from typing import List
from typing import List, Optional
import tensorflow as tf
from .transducer import Transducer
from ..utils.utils import merge_two_last_dims, get_reduced_length
Expand Down Expand Up @@ -234,8 +234,7 @@ def __init__(self,
)
self.dmodel = self.encoder.blocks[-1].dmodel
self.time_reduction_factor = 1
for block in self.encoder.blocks:
self.time_reduction_factor *= block.time_reduction_factor
for block in self.encoder.blocks: self.time_reduction_factor *= block.time_reduction_factor

def call(self, inputs, training=False, **kwargs):
features, input_length, prediction, prediction_length = inputs
Expand All @@ -244,8 +243,12 @@ def call(self, inputs, training=False, **kwargs):
outputs = self.joint_net([enc, pred], training=training, **kwargs)
return outputs

def encoder_inference(self, features):
def encoder_inference(self,
features: tf.Tensor,
input_length: Optional[tf.Tensor] = None,
with_batch: bool = False):
with tf.name_scope(f"{self.name}_encoder"):
if with_batch: return self.encoder([features, input_length], training=False)
input_length = tf.expand_dims(tf.shape(features)[0], axis=0)
outputs = tf.expand_dims(features, axis=0)
outputs = self.encoder([outputs, input_length], training=False)
Expand Down
33 changes: 12 additions & 21 deletions tensorflow_asr/models/ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
import numpy as np
import tensorflow as tf

from . import Model
from ..featurizers.speech_featurizers import TFSpeechFeaturizer
from ..featurizers.text_featurizers import TextFeaturizer
from ..utils.utils import shape_list
from ..utils.utils import shape_list, get_reduced_length


class CtcModel(Model):
Expand All @@ -41,20 +42,15 @@ def call(self, inputs, training=False, **kwargs):
# -------------------------------- GREEDY -------------------------------------

@tf.function
def recognize(self, signals):

def extract_fn(signal): return self.speech_featurizer.tf_extract(signal)

features = tf.map_fn(extract_fn, signals,
fn_output_signature=tf.TensorSpec(self.speech_featurizer.shape, dtype=tf.float32))
def recognize(self, features: tf.Tensor, input_length: Optional[tf.Tensor]):
logits = self(features, training=False)
probs = tf.nn.softmax(logits)

def map_fn(prob): return tf.numpy_function(self.perform_greedy, inp=[prob], Tout=tf.string)
def map_fn(prob): return tf.numpy_function(self.__perform_greedy, inp=[prob], Tout=tf.string)

return tf.map_fn(map_fn, probs, fn_output_signature=tf.TensorSpec([], dtype=tf.string))

def perform_greedy(self, probs: np.ndarray):
def __perform_greedy(self, probs: np.ndarray):
from ctc_decoders import ctc_greedy_decoder
decoded = ctc_greedy_decoder(probs, vocabulary=self.text_featurizer.vocab_array)
return tf.convert_to_tensor(decoded, dtype=tf.string)
Expand All @@ -71,7 +67,7 @@ def recognize_tflite(self, signal):
features = self.speech_featurizer.tf_extract(signal)
features = tf.expand_dims(features, axis=0)
input_length = shape_list(features)[1]
input_length = input_length // self.base_model.time_reduction_factor
input_length = get_reduced_length(input_length, self.base_model.time_reduction_factor)
input_length = tf.expand_dims(input_length, axis=0)
logits = self(features, training=False)
probs = tf.nn.softmax(logits)
Expand All @@ -85,25 +81,20 @@ def recognize_tflite(self, signal):
# -------------------------------- BEAM SEARCH -------------------------------------

@tf.function
def recognize_beam(self, signals, lm=False):

def extract_fn(signal): return self.speech_featurizer.tf_extract(signal)

features = tf.map_fn(extract_fn, signals,
fn_output_signature=tf.TensorSpec(self.speech_featurizer.shape, dtype=tf.float32))
def recognize_beam(self, features: tf.Tensor, input_length: Optional[tf.Tensor], lm: bool = False):
logits = self(features, training=False)
probs = tf.nn.softmax(logits)

def map_fn(prob): return tf.numpy_function(self.perform_beam_search, inp=[prob, lm], Tout=tf.string)
def map_fn(prob): return tf.numpy_function(self.__perform_beam_search, inp=[prob, lm], Tout=tf.string)

return tf.map_fn(map_fn, probs, dtype=tf.string)

def perform_beam_search(self, probs: np.ndarray, lm: bool = False):
def __perform_beam_search(self, probs: np.ndarray, lm: bool = False):
from ctc_decoders import ctc_beam_search_decoder
decoded = ctc_beam_search_decoder(
probs_seq=probs,
vocabulary=self.text_featurizer.vocab_array,
beam_size=self.text_featurizer.decoder_config["beam_width"],
beam_size=self.text_featurizer.decoder_config.beam_width,
ext_scoring_func=self.text_featurizer.scorer if lm else None
)
decoded = decoded[0][-1]
Expand All @@ -122,13 +113,13 @@ def recognize_beam_tflite(self, signal):
features = self.speech_featurizer.tf_extract(signal)
features = tf.expand_dims(features, axis=0)
input_length = shape_list(features)[1]
input_length = input_length // self.base_model.time_reduction_factor
input_length = get_reduced_length(input_length, self.base_model.time_reduction_factor)
input_length = tf.expand_dims(input_length, axis=0)
logits = self(features, training=False)
probs = tf.nn.softmax(logits)
decoded = tf.keras.backend.ctc_decode(
y_pred=probs, input_length=input_length, greedy=False,
beam_width=self.text_featurizer.decoder_config["beam_width"]
beam_width=self.text_featurizer.decoder_config.beam_width
)
decoded = tf.cast(decoded[0][0][0], dtype=tf.int32)
transcript = self.text_featurizer.indices2upoints(decoded)
Expand Down
Loading

0 comments on commit 6d70eab

Please sign in to comment.