diff --git a/README.md b/README.md index d572283..8c9d3eb 100644 --- a/README.md +++ b/README.md @@ -37,13 +37,14 @@ print(asr(file)) ### API ```python -asr(file: str, model: str = _model_name, lm: bool=False, device: str=None) +asr(data: str, model: str = _model_name, lm: bool=False, device: str=None, sampling_rate: int=16_000) ``` -- file: path of sound file +- data: path of sound file or numpy array of the voice - model: The ASR model - lm: Use language model (except *airesearch/wav2vec2-large-xlsr-53-th* model) - device: device +- sampling_rate: The sample rate - return: thai text from ASR **Options for model** diff --git a/pythaiasr/__init__.py b/pythaiasr/__init__.py index 3330cb5..7aa8e19 100644 --- a/pythaiasr/__init__.py +++ b/pythaiasr/__init__.py @@ -5,6 +5,7 @@ import logging from transformers.utils import logging logging.set_verbosity(40) +import numpy as np class ASR: @@ -57,14 +58,20 @@ def prepare_dataset(self, batch: dict) -> dict: batch["input_values"] = self.processor(batch["speech"], sampling_rate=batch["sampling_rate"]).input_values return batch - def __call__(self, file: str) -> str: + def __call__(self, data: str, sampling_rate: int=16_000) -> str: """ - :param str file: path of sound file - :param str model: The ASR model + :param str data: path of sound file or numpy array of the voice + :param int sampling_rate: The sample rate """ b = {} - b['path'] = file - a = self.prepare_dataset(self.resample(self.speech_file_to_array_fn(b))) + if isinstance(data,np.ndarray): + b["speech"] = data + b["sampling_rate"] = sampling_rate + _preprocessing = b + else: + b["path"] = data + _preprocessing = self.speech_file_to_array_fn(b) + a = self.prepare_dataset(_preprocessing) input_dict = self.processor(a["input_values"][0], return_tensors="pt", padding=True).to(self.device) logits = self.model(input_dict.input_values).logits pred_ids = torch.argmax(logits, dim=-1)[0] @@ -80,13 +87,14 @@ def __call__(self, file: str) -> str: _model = None -def asr(file: str, model: 
str = _model_name, lm: bool=False, device: str=None) -> str: +def asr(data: str, model: str = _model_name, lm: bool=False, device: str=None, sampling_rate: int=16_000) -> str: """ - :param str file: path of sound file + :param str data: path of sound file or numpy array of the voice :param str model: The ASR model name :param bool lm: Use language model (except *airesearch/wav2vec2-large-xlsr-53-th* model) :param str device: device - :return: thai text from ASR + :param int sampling_rate: The sample rate + :return: Thai text from ASR :rtype: str **Options for model** @@ -99,4 +107,4 @@ def asr(file: str, model: str = _model_name, lm: bool=False, device: str=None) - _model = ASR(model, lm=lm, device=device) _model_name = model - return _model(file=file) + return _model(data=data, sampling_rate=sampling_rate) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..d2af70e --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +""" +Unit test. + +Each file in tests/ is for each main package. 
+""" +import sys +import unittest + +sys.path.append("../pythaiasr") + +loader = unittest.TestLoader() +testSuite = loader.discover("tests") +testRunner = unittest.TextTestRunner(verbosity=1) +testRunner.run(testSuite) diff --git a/tests/common_voice_th_25686161.wav b/tests/common_voice_th_25686161.wav new file mode 100644 index 0000000..cdc26bd Binary files /dev/null and b/tests/common_voice_th_25686161.wav differ diff --git a/tests/test_asr.py b/tests/test_asr.py new file mode 100644 index 0000000..821592e --- /dev/null +++ b/tests/test_asr.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- + +import unittest +import torchaudio +from pythaiasr import asr +import os + +file = os.path.join(".", "tests", "common_voice_th_25686161.wav") + +class TestKhaveePackage(unittest.TestCase): + def test_asr(self): + self.assertIsNotNone(asr(file, device="cpu")) + def test_asr_array(self): + speech_array, sampling_rate = torchaudio.load(file) + self.assertIsNotNone(asr(speech_array[0].numpy(), device="cpu", sampling_rate=sampling_rate))