diff --git a/setup.py b/setup.py
index 1d539db..ac36801 100644
--- a/setup.py
+++ b/setup.py
@@ -29,7 +29,8 @@
             "sounddevice",
             "soundfile",
             "torch",
-            "transformers"
+            "transformers",
+            "speechbrain"
         ],
         "service": [
             "cltl.backend",
diff --git a/src/cltl/asr/speechbrain_asr.py b/src/cltl/asr/speechbrain_asr.py
new file mode 100644
index 0000000..28bfb34
--- /dev/null
+++ b/src/cltl/asr/speechbrain_asr.py
@@ -0,0 +1,61 @@
+import contextlib
+import os
+import shutil
+import tempfile
+import time
+
+import numpy as np
+from speechbrain.pretrained import EncoderDecoderASR
+
+from cltl.asr.api import ASR
+from cltl.asr.util import store_wav
+
+
+class SpeechbrainASR(ASR):
+    """ASR implementation backed by a pretrained SpeechBrain encoder-decoder model."""
+
+    def __init__(self, model_id: str, storage: str = None, model_dir: str = None):
+        """Load the model and set up the directory used for intermediate WAV files.
+
+        Args:
+            model_id: Model source passed to ``EncoderDecoderASR.from_hparams``.
+            storage: Directory for the WAV files written per transcription.
+                If empty or ``None``, a temporary directory is created and
+                each WAV file is deleted after the transcription attempt.
+            model_dir: Directory passed as ``savedir`` when loading the model.
+        """
+        self.processor = EncoderDecoderASR.from_hparams(source=model_id, savedir=model_dir)
+        # Derive both attributes from the same truthiness check so that an
+        # empty storage string does not leave WAV files piling up in the
+        # temporary directory.
+        self._clean_storage = not storage
+        self._storage = tempfile.mkdtemp() if self._clean_storage else storage
+
+    def clean(self):
+        """Delete the storage directory and everything in it.
+
+        NOTE(review): this also removes a caller-supplied ``storage`` directory.
+        """
+        shutil.rmtree(self._storage)
+
+    def speech_to_text(self, audio: np.array, sampling_rate: int) -> str:
+        """Transcribe audio by writing it to a WAV file and running the model on it.
+
+        Args:
+            audio: Audio samples, passed on to ``store_wav``.
+            sampling_rate: Sampling rate of ``audio``, passed on to ``store_wav``.
+
+        Returns:
+            The result of ``transcribe_file`` for the stored WAV file.
+        """
+        wav_file = str(os.path.join(self._storage, f"asr-{time.time()}.wav"))
+        try:
+            store_wav(audio, sampling_rate, wav_file)
+
+            return self.processor.transcribe_file(wav_file)
+        finally:
+            if self._clean_storage:
+                # The file does not exist if store_wav failed; do not let the
+                # cleanup replace that error with a FileNotFoundError.
+                with contextlib.suppress(FileNotFoundError):
+                    os.remove(wav_file)
diff --git a/tests/test_speechbrain_asr.py b/tests/test_speechbrain_asr.py
new file mode 100644
index 0000000..1805768
--- /dev/null
+++ b/tests/test_speechbrain_asr.py
@@ -0,0 +1,48 @@
+import shutil
+import tempfile
+import unittest
+
+import numpy as np
+import soundfile as sf
+from importlib_resources import path
+
+from cltl.asr.speechbrain_asr import SpeechbrainASR
+
+
+class TestSpeechbrainASR(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.tempdir = tempfile.mkdtemp()
+        cls.asr = SpeechbrainASR("speechbrain/asr-transformer-transformerlm-librispeech", cls.tempdir, cls.tempdir)
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        shutil.rmtree(cls.tempdir)
+        del cls.asr
+
+    # NOTE(review): SpeechbrainASR does not define _resample in this patch; the
+    # four resampling tests below only pass if the ASR base class provides it.
+    def test_resampling_stereo_to_mono(self):
+        resampled = self.asr._resample(np.stack((np.full((16,), 101, dtype=np.int16), np.zeros((16,), dtype=np.int16)), axis=1), 16000)
+        self.assertEqual(resampled.shape, (16,))
+        self.assertTrue(all(resampled == 50))
+
+    def test_resampling_stereo_to_mono_max_volume(self):
+        resampled = self.asr._resample(np.full((16,), 32767, dtype=np.int16), 16000)
+        self.assertEqual(resampled.shape, (16,))
+        self.assertTrue(all(resampled == 32767))
+
+    def test_resampling_single_channel_is_squeezed(self):
+        resampled = self.asr._resample(np.ones((16, 1), dtype=np.int16), 16000)
+        self.assertEqual(resampled.shape, (16,))
+
+    def test_resampling_single_channel(self):
+        resampled = self.asr._resample(np.ones((16,), dtype=np.int16), 16000)
+        self.assertEqual(resampled.shape, (16,))
+
+    def test_speech_to_text(self):
+        with path("resources", "test.wav") as wav:
+            speech_array, sampling_rate = sf.read(wav, dtype=np.int16)
+
+        transcript = self.asr.speech_to_text(speech_array, sampling_rate)
+        self.assertEqual("IT'S HEALTHIER TO COOK WITHOUT SUGAR", transcript.upper())