-
Notifications
You must be signed in to change notification settings - Fork 240
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add tts_tool, for converting a HF dataset to audio (#12)
* v1 * v2 * v1 * Polish up * Delete bq.py * args cleanup * bugfixes * Update Justfile * simple_parsing * concurrency
- Loading branch information
Showing
3 changed files
with
136 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import io | ||
import os | ||
from typing import Optional | ||
from xml.sax import saxutils | ||
|
||
import numpy as np | ||
import requests | ||
import soundfile as sf | ||
|
||
|
||
def _make_ssml(voice: str, text: str) -> str:
    """Build the SSML document for one synthesis request.

    Args:
        voice: Voice name, emitted as the ``name`` attribute of ``<voice>``.
        text: Plain text to speak; XML-escaped before being embedded.

    Returns:
        The SSML markup as a string.
    """
    # quoteattr() escapes the value AND supplies the surrounding quotes, so a
    # voice name containing `"`, `&` or `<` cannot break out of the attribute.
    # For ordinary names it yields exactly `"<name>"`, as before.
    voice_attr = saxutils.quoteattr(voice)
    return f"""
    <speak version="1.0" xml:lang="en-US">
        <voice xml:lang="en-US" name={voice_attr}>
            {saxutils.escape(text)}
        </voice>
    </speak>"""
|
||
|
||
class AzureTts:
    """Small client for the Azure text-to-speech REST endpoint.

    Each call to :meth:`tts` posts one SSML document and returns the
    synthesized speech as the bytes of a WAV file.
    """

    DEFAULT_VOICE = "en-US-JennyNeural"

    def __init__(self, voice: Optional[str] = None, sample_rate: int = 16000):
        """Create a client.

        Args:
            voice: Voice name; ``DEFAULT_VOICE`` is used when omitted or empty.
            sample_rate: Sample rate in Hz, used both for the PCM requested
                from the service and for the WAV container that is returned.
        """
        # A single Session is shared by every request made through this client.
        self._session = requests.Session()
        self._voice = voice or self.DEFAULT_VOICE
        self._sample_rate = sample_rate

    @staticmethod
    def _output_format(sample_rate: int) -> str:
        """Return the raw 16-bit mono PCM output-format name for *sample_rate*.

        Whole-kHz rates are spelled ``raw-16khz-...``; other rates (22050,
        44100) are spelled in plain Hz, e.g. ``raw-22050hz-16bit-mono-pcm``.
        Integer-dividing by 1000 unconditionally would turn 22050 into the
        non-existent ``raw-22khz-...``.
        """
        if sample_rate % 1000 == 0:
            return f"raw-{sample_rate // 1000}khz-16bit-mono-pcm"
        return f"raw-{sample_rate}hz-16bit-mono-pcm"

    def tts(self, text: str) -> bytes:
        """Synthesize *text* and return it as WAV file bytes.

        Args:
            text: Plain text to speak.

        Returns:
            The bytes of a WAV file at the client's sample rate.

        Raises:
            ValueError: If neither ``AZURE_TTS_API_KEY`` nor
                ``AZURE_WESTUS_TTS_API_KEY`` is set in the environment.
            requests.HTTPError: If the service answers with an error status.
        """
        region = "westus"
        api_key = os.environ.get("AZURE_TTS_API_KEY") or os.environ.get(
            "AZURE_WESTUS_TTS_API_KEY"
        )
        if not api_key:
            # Fail with an actionable message instead of letting a None header
            # value surface as an obscure error from the HTTP layer.
            raise ValueError(
                "No Azure TTS API key found; set AZURE_TTS_API_KEY or "
                "AZURE_WESTUS_TTS_API_KEY."
            )

        url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1"
        headers = {
            "Ocp-Apim-Subscription-Key": api_key,
            "Content-Type": "application/ssml+xml",
            "X-Microsoft-OutputFormat": self._output_format(self._sample_rate),
            "User-Agent": "MyTTS",
        }
        # Encode explicitly: a str body is encoded by the HTTP stack with a
        # version-dependent charset (Latin-1 in some stacks), which breaks on
        # text outside that range.
        body = _make_ssml(self._voice, text).encode("utf-8")
        response = self._session.post(url, headers=headers, data=body)
        response.raise_for_status()

        # The service returns headerless 16-bit PCM; wrap it in a WAV container.
        pcm_array = np.frombuffer(response.content, dtype=np.int16)
        wav_bytes = io.BytesIO()
        sf.write(wav_bytes, pcm_array, self._sample_rate, format="wav")
        return wav_bytes.getvalue()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import dataclasses | ||
import os | ||
from concurrent import futures | ||
from typing import Dict, Optional, Union | ||
|
||
import datasets | ||
import simple_parsing | ||
|
||
from ultravox.tools import tts | ||
|
||
|
||
# This script is used to generate audio samples from text using a TTS model.
# Ex: just tts -d google/boolq -c question -a audio -u fixie-ai/boolq-audio
@dataclasses.dataclass
class TtsArgs:
    """Command-line options for converting a dataset's text column to audio."""

    # NOTE(review): simple_parsing can surface field comments as --help text;
    # confirm the wording below reads well there.

    # Name of the dataset, passed to datasets.load_dataset.
    dataset_name: str = simple_parsing.field(alias="-d")
    # Dataset subset (config) to load; also used as config_name when uploading.
    dataset_subset: str = simple_parsing.field(default="default", alias="-S")
    # Split to process; when unset, every split that load_dataset returns is processed.
    dataset_split: Optional[str] = simple_parsing.field(default=None, alias="-s")
    # Column holding the text to speak (a string, or a dict with a "text" key).
    column_name: str = simple_parsing.field(default="question", alias="-c")
    # Column that receives the audio; defaults to "<column_name>_audio".
    audio_column_name: Optional[str] = simple_parsing.field(default=None, alias="-a")
    # When set, only the first num_samples rows of each split are processed.
    num_samples: Optional[int] = simple_parsing.field(default=None, alias="-n")
    # Maximum number of threads issuing TTS requests for a batch.
    num_workers: int = simple_parsing.field(default=16, alias="-w")
    # Voice name handed to the TTS client; unset means the client's default voice.
    voice: Optional[str] = simple_parsing.field(default=None, alias="-V")
    # Sample rate handed to the TTS client.
    sample_rate: int = simple_parsing.field(default=16000, alias="-r")
    # Name to push the result to; when unset, each split is written to a local parquet file.
    upload_name: Optional[str] = simple_parsing.field(default=None, alias="-u")
    # Token used when pushing; falls back to the HF_TOKEN environment variable.
    token: Optional[str] = simple_parsing.field(default=None, alias="-t")
|
||
|
||
def _tts_split(
    tts_client: tts.AzureTts,
    ds_split: datasets.IterableDataset,
    col_name: str,
    audio_col_name: str,
    num_workers: int,
):
    """Return *ds_split* with an audio column synthesized from a text column.

    Args:
        tts_client: Client whose ``tts`` method turns text into audio bytes.
        ds_split: The dataset split to process.
        col_name: Column holding the text (a string, or a dict with a "text" key).
        audio_col_name: Column that receives the synthesized audio.
        num_workers: Maximum number of concurrent TTS calls per batch.

    Returns:
        The mapped split, with ``audio_col_name`` cast to ``datasets.Audio``
        at the client's sample rate.
    """

    def _as_text(value: Union[str, Dict[str, str]]) -> str:
        # Some datasets wrap the text in a dict under the "text" key.
        if isinstance(value, dict):
            return value["text"]
        return value

    def _synthesize_batch(batch):
        with futures.ThreadPoolExecutor(max_workers=num_workers) as pool:
            # Resolve every text first, then fan the requests out to the pool.
            texts = list(map(_as_text, batch[col_name]))
            pending = []
            for text in texts:
                pending.append(pool.submit(tts_client.tts, text))
            # Collect in submission order so audio lines up with its row.
            batch[audio_col_name] = [job.result() for job in pending]
        return batch

    with_audio = ds_split.map(_synthesize_batch, batched=True)
    audio_feature = datasets.Audio(sampling_rate=tts_client._sample_rate)
    return with_audio.cast_column(audio_col_name, audio_feature)
|
||
|
||
def main(args: TtsArgs):
    """Load a dataset, synthesize audio for a text column, and save or upload it.

    Each processed split is either written to ``<split>-00000-of-00001.parquet``
    in the current directory or, when ``args.upload_name`` is set, pushed to
    the Hub under that name.

    Args:
        args: Parsed command-line options.
    """
    ds_name = args.dataset_name
    col_name = args.column_name
    audio_col_name = args.audio_column_name or f"{col_name}_audio"
    tts_client = tts.AzureTts(voice=args.voice, sample_rate=args.sample_rate)

    print(f'Loading dataset "{ds_name}", mapping "{col_name}" to "{audio_col_name}"...')
    data_dict = datasets.load_dataset(
        ds_name, args.dataset_subset, split=args.dataset_split
    )
    if args.dataset_split:
        # With an explicit split, load_dataset returns a single dataset; wrap
        # it so the loop below handles both cases uniformly.
        data_dict = {args.dataset_split: data_dict}
    for split, ds_split in data_dict.items():
        print(f'Processing split "{split}"...')
        if args.num_samples:
            # Clamp to the split's size: select() rejects out-of-range indices,
            # and splits of one dataset commonly differ in length.
            limit = min(args.num_samples, len(ds_split))
            ds_split = ds_split.select(range(limit))
        new_split = _tts_split(
            tts_client, ds_split, col_name, audio_col_name, args.num_workers
        )

        if not args.upload_name:
            output_name = f"{split}-00000-of-00001.parquet"
            new_split.to_parquet(output_name)
        else:
            token = args.token or os.environ.get("HF_TOKEN")
            new_split.push_to_hub(
                args.upload_name,
                config_name=args.dataset_subset,
                split=split,
                token=token,
            )
|
||
|
||
if __name__ == "__main__":
    # Parse the command line into TtsArgs, then run the conversion.
    cli_args = simple_parsing.parse(TtsArgs)
    main(cli_args)