Skip to content

Commit

Permalink
Add device
Browse files: browse the repository at this point in the history.
  • Loading branch information
wannaphong committed Mar 10, 2023
1 parent c12fb9b commit e1671c1
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions pythaiasr/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -6,8 +6,6 @@
from transformers.utils import logging
logging.set_verbosity(40)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class ASR:
def __init__(self, model: str="airesearch/wav2vec2-large-xlsr-53-th", lm: bool=False, device: str=None) -> None:
Expand All @@ -28,17 +26,19 @@ def __init__(self, model: str="airesearch/wav2vec2-large-xlsr-53-th", lm: bool=F
"wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut"
]
assert self.model_name in self.support_model
self.lm =lm
self.lm = lm
if device!=None:
self.device = torch.device(device)
else:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if not self.lm:
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
self.model = Wav2Vec2ForCTC.from_pretrained(self.model_name)
self.model = Wav2Vec2ForCTC.from_pretrained(self.model_name).to(self.device)
else:
from transformers import AutoProcessor, AutoModelForCTC
self.processor = AutoProcessor.from_pretrained(self.model_name)
self.model = AutoModelForCTC.from_pretrained(self.model_name)
if device!=None:
self.device = torch.device(device)
self.model = AutoModelForCTC.from_pretrained(self.model_name).to(self.device)

def speech_file_to_array_fn(self, batch: dict) -> dict:
speech_array, sampling_rate = torchaudio.load(batch["path"])
Expand All @@ -65,7 +65,7 @@ def __call__(self, file: str) -> str:
b = {}
b['path'] = file
a = self.prepare_dataset(self.resample(self.speech_file_to_array_fn(b)))
input_dict = self.processor(a["input_values"][0], return_tensors="pt", padding=True)
input_dict = self.processor(a["input_values"][0], return_tensors="pt", padding=True).to(self.device)
logits = self.model(input_dict.input_values).logits
pred_ids = torch.argmax(logits, dim=-1)[0]
if self.model_name == "airesearch/wav2vec2-large-xlsr-53-th":
Expand Down

0 comments on commit e1671c1

Please sign in to comment.