Skip to content

Commit

Permalink
Add tts_tool, for converting a HF dataset to audio (#12)
Browse files Browse the repository at this point in the history
* v1

* v2

* v1

* Polish up

* Delete bq.py

* args cleanup

* bugfixes

* Update Justfile

* simple_parsing

* concurrency
  • Loading branch information
juberti authored Jun 6, 2024
1 parent 6f387d0 commit 4ffb5ad
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 0 deletions.
3 changes: 3 additions & 0 deletions Justfile
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ infer *FLAGS:
eval *FLAGS:
just python -m ultravox.tools.eval_tool {{FLAGS}}
tts *FLAGS:
just python -m ultravox.tools.tts_tool {{FLAGS}}
mds *FLAGS:
just python -m ultravox.tools.mds_tool {{FLAGS}}
Expand Down
48 changes: 48 additions & 0 deletions ultravox/tools/tts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import io
import os
from typing import Optional
from xml.sax import saxutils

import numpy as np
import requests
import soundfile as sf


def _make_ssml(voice: str, text: str):
return f"""
<speak version="1.0" xml:lang="en-US">
<voice xml:lang="en-US" name="{voice}">
{saxutils.escape(text)}
</voice>
</speak>"""


class AzureTts:
    """Small client for the Azure text-to-speech REST endpoint.

    Each call to :meth:`tts` posts an SSML document to the ``westus``
    endpoint and returns the synthesized speech as WAV-encoded bytes.
    """

    DEFAULT_VOICE = "en-US-JennyNeural"

    def __init__(self, voice: Optional[str] = None, sample_rate: int = 16000):
        """Create a client.

        Args:
            voice: Voice name to synthesize with; ``DEFAULT_VOICE`` is used
                when this is ``None`` or empty.
            sample_rate: Sample rate in Hz, used both for the requested raw
                PCM format and for the WAV header that is written.
        """
        # One session is shared by all requests made through this client.
        self._session = requests.Session()
        self._voice = voice or self.DEFAULT_VOICE
        self._sample_rate = sample_rate

    @staticmethod
    def _output_format(sample_rate: int) -> str:
        """Return the raw 16-bit mono PCM output format name for a rate.

        Rates that are a whole number of kHz are spelled ``<n>khz`` (for
        example ``raw-16khz-16bit-mono-pcm``); other rates are spelled in Hz
        (for example ``raw-22050hz-16bit-mono-pcm``).
        """
        if sample_rate % 1000 == 0:
            rate = f"{sample_rate // 1000}khz"
        else:
            rate = f"{sample_rate}hz"
        return f"raw-{rate}-16bit-mono-pcm"

    def tts(self, text: str):
        """Synthesize ``text`` and return it as WAV file bytes.

        Args:
            text: Plain text to speak; it is XML-escaped when the SSML
                request body is built.

        Returns:
            The bytes of a WAV file at the configured sample rate.

        Raises:
            ValueError: If neither ``AZURE_TTS_API_KEY`` nor
                ``AZURE_WESTUS_TTS_API_KEY`` is set in the environment.
            requests.HTTPError: If the service answers with an error status.
        """
        region = "westus"
        api_key = os.environ.get("AZURE_TTS_API_KEY") or os.environ.get(
            "AZURE_WESTUS_TTS_API_KEY"
        )
        if not api_key:
            # Fail before any network traffic: a header whose value is None
            # would not be sent, and the request would be rejected later
            # with a much less helpful authentication error.
            raise ValueError(
                "No Azure TTS API key found; set AZURE_TTS_API_KEY or "
                "AZURE_WESTUS_TTS_API_KEY."
            )
        output_format = self._output_format(self._sample_rate)
        url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1"
        headers = {
            "Ocp-Apim-Subscription-Key": api_key,
            "Content-Type": "application/ssml+xml",
            "X-Microsoft-OutputFormat": output_format,
            "User-Agent": "MyTTS",
        }
        body = _make_ssml(self._voice, text)
        response = self._session.post(url, headers=headers, data=body)
        response.raise_for_status()

        # The response is headerless 16-bit PCM (per the requested format);
        # wrap it in a WAV container so that it is self-describing.
        pcm_array = np.frombuffer(response.content, dtype=np.int16)
        wav_bytes = io.BytesIO()
        sf.write(wav_bytes, pcm_array, self._sample_rate, format="wav")
        return wav_bytes.getvalue()
85 changes: 85 additions & 0 deletions ultravox/tools/tts_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import dataclasses
import os
from concurrent import futures
from typing import Dict, Optional, Union

import datasets
import simple_parsing

from ultravox.tools import tts


# This script is used to generate audio samples from text using a TTS model.
# Ex: just tts -d google/boolq -c question -a audio -u fixie-ai/boolq-audio
@dataclasses.dataclass
class TtsArgs:
    # Command-line options for the TTS conversion tool; each field has a
    # short flag alias.

    # Name of the dataset to load (first argument to datasets.load_dataset).
    dataset_name: str = simple_parsing.field(alias="-d")
    # Dataset subset/config to load; also used as config_name when uploading.
    dataset_subset: str = simple_parsing.field(default="default", alias="-S")
    # Single split to process; when unset, every loaded split is processed.
    dataset_split: Optional[str] = simple_parsing.field(default=None, alias="-s")
    # Column holding the text to synthesize.
    column_name: str = simple_parsing.field(default="question", alias="-c")
    # Column to store the audio in; defaults to "<column_name>_audio".
    audio_column_name: Optional[str] = simple_parsing.field(default=None, alias="-a")
    # If set, only the first num_samples rows of each split are processed.
    num_samples: Optional[int] = simple_parsing.field(default=None, alias="-n")
    # Maximum number of worker threads issuing TTS requests per batch.
    num_workers: int = simple_parsing.field(default=16, alias="-w")
    # Voice name passed to the TTS client; its default voice is used if unset.
    voice: Optional[str] = simple_parsing.field(default=None, alias="-V")
    # Sample rate in Hz passed to the TTS client.
    sample_rate: int = simple_parsing.field(default=16000, alias="-r")
    # Hub name to push results to; when unset, parquet files are written
    # locally instead.
    upload_name: Optional[str] = simple_parsing.field(default=None, alias="-u")
    # Token used for uploading; falls back to the HF_TOKEN environment
    # variable when unset.
    token: Optional[str] = simple_parsing.field(default=None, alias="-t")


def _tts_split(
    tts_client: tts.AzureTts,
    ds_split: datasets.IterableDataset,
    col_name: str,
    audio_col_name: str,
    num_workers: int,
):
    """Add a synthesized-audio column to one dataset split.

    Args:
        tts_client: Client whose ``tts`` method turns text into audio bytes.
        ds_split: The split to process.
        col_name: Column containing the text (either plain strings or dicts
            with a ``"text"`` key).
        audio_col_name: Name of the column that receives the audio.
        num_workers: Maximum number of threads used per batch.

    Returns:
        The mapped split with ``audio_col_name`` cast to an audio feature at
        the client's sample rate.
    """

    def _as_text(value: Union[str, Dict[str, str]]) -> str:
        # Some columns wrap the string in a dict under the "text" key.
        if isinstance(value, dict):
            return value["text"]
        return value

    def _synthesize_batch(batch):
        pool = futures.ThreadPoolExecutor(max_workers=num_workers)
        with pool:
            prompts = [_as_text(value) for value in batch[col_name]]
            # Submit every request first so they run concurrently, then
            # collect the results in the original row order.
            pending = [pool.submit(tts_client.tts, prompt) for prompt in prompts]
            clips = []
            for job in pending:
                clips.append(job.result())
            batch[audio_col_name] = clips
        return batch

    with_audio = ds_split.map(_synthesize_batch, batched=True)
    audio_feature = datasets.Audio(sampling_rate=tts_client._sample_rate)
    return with_audio.cast_column(audio_col_name, audio_feature)


def main(args: TtsArgs):
    """Synthesize audio for a text column of a dataset and save or upload it.

    Loads the requested dataset (one split, or all splits when no split is
    given), runs text-to-speech over ``args.column_name`` for each split, and
    then either writes a parquet file per split to the current directory or
    pushes each split to the hub when ``args.upload_name`` is set.

    Args:
        args: Parsed command-line options.
    """
    ds_name = args.dataset_name
    col_name = args.column_name
    audio_col_name = args.audio_column_name or f"{col_name}_audio"
    tts_client = tts.AzureTts(voice=args.voice, sample_rate=args.sample_rate)

    print(f'Loading dataset "{ds_name}", mapping "{col_name}" to "{audio_col_name}"...')
    data_dict = datasets.load_dataset(
        ds_name, args.dataset_subset, split=args.dataset_split
    )
    if args.dataset_split:
        # With an explicit split a single dataset is returned rather than a
        # mapping; wrap it so that the loop below handles both cases.
        data_dict = {args.dataset_split: data_dict}
    for split, ds_split in data_dict.items():
        print(f'Processing split "{split}"...')
        if args.num_samples:
            # Clamp to the split size: selecting indices past the end raises,
            # which would otherwise abort on any split shorter than the
            # requested sample count.
            limit = min(args.num_samples, len(ds_split))
            ds_split = ds_split.select(range(limit))
        new_split = _tts_split(
            tts_client, ds_split, col_name, audio_col_name, args.num_workers
        )

        if not args.upload_name:
            output_name = f"{split}-00000-of-00001.parquet"
            new_split.to_parquet(output_name)
        else:
            token = args.token or os.environ.get("HF_TOKEN")
            new_split.push_to_hub(
                args.upload_name,
                config_name=args.dataset_subset,
                split=split,
                token=token,
            )


if __name__ == "__main__":
    # Parse the command line into TtsArgs, then run the conversion.
    cli_args = simple_parsing.parse(TtsArgs)
    main(cli_args)

0 comments on commit 4ffb5ad

Please sign in to comment.