Skip to content

Commit

Permalink
Add input array
Browse files Browse the repository at this point in the history
  • Loading branch information
wannaphong committed Mar 19, 2023
1 parent e1671c1 commit eea5aed
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 11 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,14 @@ print(asr(file))
### API

```python
asr(file: str, model: str = _model_name, lm: bool=False, device: str=None)
asr(data: str, model: str = _model_name, lm: bool=False, device: str=None, sampling_rate: int=16_000)
```

- file: path of sound file
- data: path of a sound file, or a NumPy array containing the audio samples
- model: The ASR model
- lm: Use language model (except *airesearch/wav2vec2-large-xlsr-53-th* model)
- device: device
- sampling_rate: the sample rate of the audio when `data` is a NumPy array (default: 16,000)
- return: Thai text from ASR

**Options for model**
Expand Down
26 changes: 17 additions & 9 deletions pythaiasr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
from transformers.utils import logging
logging.set_verbosity(40)
import numpy as np


class ASR:
Expand Down Expand Up @@ -57,14 +58,20 @@ def prepare_dataset(self, batch: dict) -> dict:
batch["input_values"] = self.processor(batch["speech"], sampling_rate=batch["sampling_rate"]).input_values
return batch

def __call__(self, file: str) -> str:
def __call__(self, data: str, sampling_rate: int=16_000) -> str:
"""
:param str file: path of sound file
:param str model: The ASR model
:param str data: path of sound file or numpy array of the voice
:param int sampling_rate: The sample rate
"""
b = {}
b['path'] = file
a = self.prepare_dataset(self.resample(self.speech_file_to_array_fn(b)))
if isinstance(data,np.ndarray):
b["speech"] = data
b["sampling_rate"] = sampling_rate
_preprocessing = b
else:
b["path"] = data
_preprocessing = self.speech_file_to_array_fn(b)
a = self.prepare_dataset(b)
input_dict = self.processor(a["input_values"][0], return_tensors="pt", padding=True).to(self.device)
logits = self.model(input_dict.input_values).logits
pred_ids = torch.argmax(logits, dim=-1)[0]
Expand All @@ -80,13 +87,14 @@ def __call__(self, file: str) -> str:
_model = None


def asr(file: str, model: str = _model_name, lm: bool=False, device: str=None) -> str:
def asr(data: str, model: str = _model_name, lm: bool=False, device: str=None, sampling_rate: int=16_000) -> str:
"""
:param str file: path of sound file
:param str data: path of sound file or numpy array of the voice
:param str model: The ASR model name
:param bool lm: Use language model (except *airesearch/wav2vec2-large-xlsr-53-th* model)
:param str device: device
:return: thai text from ASR
:param int sampling_rate: The sample rate
:return: Thai text from ASR
:rtype: str
**Options for model**
Expand All @@ -99,4 +107,4 @@ def asr(file: str, model: str = _model_name, lm: bool=False, device: str=None) -
_model = ASR(model, lm=lm, device=device)
_model_name = model

return _model(file=file)
return _model(data=data, sampling_rate=sampling_rate)
15 changes: 15 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-
"""
Unit test.
Each file in tests/ is for each main package.

Executing this module builds a unittest suite by discovering test modules
under the ``tests`` start directory and runs it immediately with a text
test runner. All of this happens at module level, so it is triggered as a
side effect of running or importing this module.
"""
import sys
import unittest

# Add the relative path "../pythaiasr" to the module search path. A relative
# sys.path entry is resolved against the current working directory, so the
# effect depends on where the process is started.
# NOTE(review): presumably this lets the tests import the package without
# installing it first - confirm.
sys.path.append("../pythaiasr")

# Collect the test modules found under the "tests" start directory (also
# resolved relative to the current working directory).
loader = unittest.TestLoader()
testSuite = loader.discover("tests")
# verbosity=1 is passed explicitly to the text runner.
testRunner = unittest.TextTestRunner(verbosity=1)
# Run the discovered suite; the returned result object is not kept.
testRunner.run(testSuite)
Binary file added tests/common_voice_th_25686161.wav
Binary file not shown.
15 changes: 15 additions & 0 deletions tests/test_asr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-

import unittest
import torchaudio
from pythaiasr import asr
import os

file = os.path.join(".", "tests", "common_voice_th_25686161.wav")

class TestKhaveePackage(unittest.TestCase):
    """Smoke tests for ``asr`` using the sample recording referenced by ``file``."""

    def test_asr(self):
        """``asr`` yields a non-None result when given the path of a sound file."""
        transcript = asr(file, device="cpu")
        self.assertIsNotNone(transcript)

    def test_asr_array(self):
        """``asr`` yields a non-None result when given a NumPy array and its sample rate."""
        waveform, rate = torchaudio.load(file)
        # Only the first row of the loaded tensor is passed on.
        # NOTE(review): presumably this is the first audio channel - confirm
        # against the torchaudio.load return layout.
        samples = waveform[0].numpy()
        transcript = asr(samples, device="cpu", sampling_rate=rate)
        self.assertIsNotNone(transcript)

0 comments on commit eea5aed

Please sign in to comment.