Skip to content

Commit

Permalink
Add speechbrain based ASR implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
numblr committed Nov 15, 2021
1 parent 7bfd84a commit a7b1690
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 1 deletion.
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
"sounddevice",
"soundfile",
"torch",
"transformers"
"transformers",
"speechbrain"
],
"service": [
"cltl.backend",
Expand Down
30 changes: 30 additions & 0 deletions src/cltl/asr/speechbrain_asr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import os
import shutil
import tempfile
import time

import numpy as np
from speechbrain.pretrained import EncoderDecoderASR

from cltl.asr.api import ASR
from cltl.asr.util import store_wav


class SpeechbrainASR(ASR):
    """ASR implementation backed by a pretrained SpeechBrain ``EncoderDecoderASR``.

    Each request is written to a WAV file inside a storage directory and the
    model transcribes that file. If no storage directory is supplied, a
    temporary directory is created and the per-request WAV files are removed
    again after transcription; otherwise the files are kept in the supplied
    directory.
    """

    def __init__(self, model_id: str, storage: str = None, model_dir: str = None):
        """Load the model and set up the storage directory.

        Args:
            model_id: Model source passed to ``EncoderDecoderASR.from_hparams``.
            storage: Directory for the intermediate WAV files. If ``None`` or
                empty, a temporary directory is created with
                ``tempfile.mkdtemp()``.
            model_dir: Passed as ``savedir`` to ``EncoderDecoderASR.from_hparams``.
        """
        self.processor = EncoderDecoderASR.from_hparams(source=model_id, savedir=model_dir)

        # Derive the ownership flag from the same truthiness test that selects
        # the directory. Previously an empty string fell back to a temporary
        # directory while the flag (``storage is None``) stayed False, so the
        # WAV files written there were never removed.
        self._clean_storage = not storage
        self._storage = tempfile.mkdtemp() if self._clean_storage else storage

    def clean(self):
        """Recursively delete the storage directory.

        NOTE(review): this also deletes a directory supplied by the caller,
        not only a temporary one created by this instance. Confirm that this
        is intended.
        """
        shutil.rmtree(self._storage)

    def speech_to_text(self, audio: np.array, sampling_rate: int) -> str:
        """Transcribe the given audio.

        Args:
            audio: Audio samples, forwarded unchanged to ``store_wav``.
            sampling_rate: Sampling rate of ``audio``, forwarded to ``store_wav``.

        Returns:
            The result of ``transcribe_file`` on the stored WAV file.
        """
        wav_file = str(os.path.join(self._storage, f"asr-{time.time()}.wav"))
        try:
            store_wav(audio, sampling_rate, wav_file)

            return self.processor.transcribe_file(wav_file)
        finally:
            if self._clean_storage:
                try:
                    os.remove(wav_file)
                except FileNotFoundError:
                    # store_wav failed before the file existed. Let the
                    # original exception propagate instead of replacing it
                    # with a cleanup error.
                    pass
46 changes: 46 additions & 0 deletions tests/test_speechbrain_asr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import shutil
import tempfile
import unittest

import numpy as np
import soundfile as sf
from importlib_resources import path

from cltl.asr.speechbrain_asr import SpeechbrainASR


class TestSpeechbrainASR(unittest.TestCase):
    """Tests for ``SpeechbrainASR`` sharing one model instance across all cases.

    NOTE(review): the resampling tests call ``self.asr._resample``. That method
    is not defined in the ``SpeechbrainASR`` class itself, so it is presumably
    expected from the ``ASR`` base class. Confirm that it exists.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Create a temporary directory and the ASR instance used by every test."""
        workdir = tempfile.mkdtemp()
        cls.tempdir = workdir
        # The same directory serves as WAV storage and as model directory.
        cls.asr = SpeechbrainASR(
            "speechbrain/asr-transformer-transformerlm-librispeech",
            workdir,
            workdir,
        )

    @classmethod
    def tearDownClass(cls) -> None:
        """Remove the temporary directory and drop the shared ASR instance."""
        shutil.rmtree(cls.tempdir)
        del cls.asr

    def test_resampling_stereo_to_mono(self):
        """Two channels holding 101 and 0 are expected to yield 50 per sample."""
        loud_channel = np.full((16,), 101, dtype=np.int16)
        silent_channel = np.zeros((16,), dtype=np.int16)
        stereo = np.stack((loud_channel, silent_channel), axis=1)

        resampled = self.asr._resample(stereo, 16000)

        self.assertEqual(resampled.shape, (16,))
        self.assertTrue(all(resampled == 50))

    def test_resampling_stereo_to_mono_max_volume(self):
        """Samples at the int16 maximum are expected to come back unchanged."""
        max_volume = np.full((16,), 32767, dtype=np.int16)

        resampled = self.asr._resample(max_volume, 16000)

        self.assertEqual(resampled.shape, (16,))
        self.assertTrue(all(resampled == 32767))

    def test_resampling_single_channel_is_squeezed(self):
        """A (16, 1) input is expected to be returned with shape (16,)."""
        column = np.ones((16, 1), dtype=np.int16)

        resampled = self.asr._resample(column, 16000)

        self.assertEqual(resampled.shape, (16,))

    def test_resampling_single_channel(self):
        """A (16,) input is expected to keep its shape."""
        mono = np.ones((16,), dtype=np.int16)

        resampled = self.asr._resample(mono, 16000)

        self.assertEqual(resampled.shape, (16,))

    def test_speech_to_text(self):
        """The bundled test recording is transcribed to the expected sentence."""
        with path("resources", "test.wav") as wav:
            speech_array, sampling_rate = sf.read(wav, dtype=np.int16)

        expected = "IT'S HEALTHIER TO COOK WITHOUT SUGAR"
        transcript = self.asr.speech_to_text(speech_array, sampling_rate)
        self.assertEqual(expected, transcript.upper())

0 comments on commit a7b1690

Please sign in to comment.