Explicitly specify protocol (http, ws, grpc) when calling tts(); include backward compatibility
bryananderson committed Jan 13, 2025
1 parent e109720 commit c060389
Showing 4 changed files with 188 additions and 73 deletions.
31 changes: 14 additions & 17 deletions README.md
@@ -79,17 +79,14 @@ The `tts` method takes the following arguments:
- `text`: The text to be converted to speech; a string or list of strings.
- `options`: The options to use for the TTS request; a `TTSOptions` object [(see below)](#ttsoptions).
- `voice_engine`: The voice engine to use for the TTS request; a string (default `Play3.0-mini-http`).
- `PlayDialog-*`: Our large, expressive English model, which also supports multi-turn two-speaker dialogues.
- `PlayDialog-http`: Streaming and non-streaming audio over HTTP.
- `PlayDialog-ws`: Streaming audio over WebSockets.
- `PlayDialogMultilingual-*`: Our large, expressive multilingual model, which also supports multi-turn two-speaker dialogues.
- `PlayDialogMultilingual-http`: Streaming and non-streaming audio over HTTP.
- `PlayDialogMultilingual-ws`: Streaming audio over WebSockets.
- `Play3.0-mini-*`: Our small, fast multilingual model.
- `Play3.0-mini-http`: Streaming and non-streaming audio over HTTP.
- `Play3.0-mini-ws`: Streaming audio over WebSockets.
- `Play3.0-mini-grpc`: Streaming audio over gRPC. NOTE: This voice engine is ONLY available for Play On-Prem customers.
- `PlayHT2.0-turbo`: Our legacy English-only model, streaming audio over gRPC.
- `PlayDialog`: Our large, expressive English model, which also supports multi-turn two-speaker dialogues.
- `PlayDialogMultilingual`: Our large, expressive multilingual model, which also supports multi-turn two-speaker dialogues.
- `Play3.0-mini`: Our small, fast multilingual model.
- `PlayHT2.0-turbo`: Our legacy English-only model
- `protocol`: The protocol to use to communicate with the Play API (`http` by default except for `PlayHT2.0-turbo` which is `grpc` by default).
- `http`: Streaming and non-streaming audio over HTTP (supports `Play3.0-mini`, `PlayDialog`, and `PlayDialogMultilingual`).
- `ws`: Streaming audio over WebSockets (supports `Play3.0-mini`, `PlayDialog`, and `PlayDialogMultilingual`).
- `grpc`: Streaming audio over gRPC (supports `PlayHT2.0-turbo` for all, and `Play3.0-mini` ONLY for Play On-Prem customers).
- `streaming`: Whether or not to stream the audio in chunks (default True); non-streaming is only enabled for HTTP endpoints.
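
The defaults and compatibility rules listed above can be sketched as a small resolver. This is a minimal illustration, not the library's `get_voice_engine_and_protocol`, whose body is not shown in this diff. The function name `resolve_engine_and_protocol`, the `SUPPORTED` table, and the rule that an explicit `protocol` wins over a legacy suffix are all assumptions made for the example. The on-prem restriction on `Play3.0-mini` over gRPC is not modeled.

```python
from typing import Optional, Tuple

# Hypothetical support table derived from the README text above.
SUPPORTED = {
    "Play3.0-mini": {"http", "ws", "grpc"},
    "PlayDialog": {"http", "ws"},
    "PlayDialogMultilingual": {"http", "ws"},
    "PlayHT2.0-turbo": {"grpc"},
}
DEFAULT_ENGINE = "Play3.0-mini"
PROTOCOLS = ("http", "ws", "grpc")


def resolve_engine_and_protocol(
    voice_engine: Optional[str] = None,
    protocol: Optional[str] = None,
) -> Tuple[str, str]:
    """Return (engine, protocol), accepting legacy names like 'PlayDialog-ws'."""
    engine = voice_engine or DEFAULT_ENGINE
    # Backward compatibility: strip a legacy protocol suffix from the engine name.
    for suffix in PROTOCOLS:
        if engine.endswith("-" + suffix):
            engine = engine[: -len(suffix) - 1]
            if protocol is None:  # assumption: an explicit protocol takes priority
                protocol = suffix
            break
    if engine not in SUPPORTED:
        raise ValueError(f"Unknown voice engine {engine}")
    if protocol is None:
        # http by default, except PlayHT2.0-turbo, which defaults to grpc.
        protocol = "grpc" if engine == "PlayHT2.0-turbo" else "http"
    if protocol not in SUPPORTED[engine]:
        raise ValueError(f"{engine} does not support protocol {protocol}")
    return engine, protocol
```

Under these assumptions, an old call such as `resolve_engine_and_protocol("PlayDialog-ws")` and a new call such as `resolve_engine_and_protocol("PlayDialog", "ws")` resolve to the same pair.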

### TTSOptions
@@ -117,12 +114,12 @@ The `TTSOptions` class is used to specify the options for the TTS request. It ha
- The following options are inference-time hyperparameters of the text-to-speech model; if unset, the model will use default values chosen by Play.
- `temperature` (all models): The temperature of the model, a float.
- `top_p` (all models): The top_p of the model, a float.
- `text_guidance` (`Play3.0-mini-*` and `PlayHT2.0-turbo` only): The text_guidance of the model, a float.
- `voice_guidance` (`Play3.0-mini-*` and `PlayHT2.0-turbo` only): The voice_guidance of the model, a float.
- `style_guidance` (`Play3.0-mini-*` only): The style_guidance of the model, a float.
- `repetition_penalty` (`Play3.0-mini-*` and `PlayHT2.0-turbo` only): The repetition_penalty of the model, a float.
- `text_guidance` (`Play3.0-mini` and `PlayHT2.0-turbo` only): The text_guidance of the model, a float.
- `voice_guidance` (`Play3.0-mini` and `PlayHT2.0-turbo` only): The voice_guidance of the model, a float.
- `style_guidance` (`Play3.0-mini` only): The style_guidance of the model, a float.
- `repetition_penalty` (`Play3.0-mini` and `PlayHT2.0-turbo` only): The repetition_penalty of the model, a float.
- `disable_stabilization` (`PlayHT2.0-turbo` only): Disable the audio stabilization process, a boolean (default `False`).
- `language` (`Play3.0-*` and `PlayDialogMultilingual-*` only): The language of the text to be spoken, a `Language` enum value or `None` (default `ENGLISH`).
- `language` (`Play3.0` and `PlayDialogMultilingual` only): The language of the text to be spoken, a `Language` enum value or `None` (default `ENGLISH`).
- `AFRIKAANS`
- `ALBANIAN`
- `AMHARIC`
@@ -160,7 +157,7 @@ The `TTSOptions` class is used to specify the options for the TTS request. It ha
- `UKRAINIAN`
- `URDU`
- `XHOSA`
- The following options are additional inference-time hyperparameters which only apply to the `PlayDialog-*` and `PlayDialogMultilingual-*` models; if unset, the model will use default values chosen by Play.
- The following options are additional inference-time hyperparameters which only apply to the `PlayDialog` and `PlayDialogMultilingual` models; if unset, the model will use default values chosen by Play.
- `voice_2` (multi-turn dialogue only): The second voice to use for a multi-turn TTS request; a string.
- A URL pointing to a Play voice manifest file.
- `turn_prefix` (multi-turn dialogue only): The prefix for the first speaker's turns in a multi-turn TTS request; a string.
27 changes: 15 additions & 12 deletions pyht/async_client.py
@@ -230,6 +230,8 @@ async def stream_tts_input(
text_stream: Union[AsyncGenerator[str, None], AsyncIterable[str]],
options: TTSOptions,
voice_engine: Optional[str] = None,
protocol: Optional[str] = None,
streaming: bool = True
):
"""Stream input to Play via the text_stream object."""
buffer = io.StringIO()
@@ -239,37 +241,36 @@ async def stream_tts_input(
buffer.write(" ") # normalize word spacing.
if SENTENCE_END_REGEX.match(t) is None:
continue
async for data in self.tts(buffer.getvalue(), options, voice_engine):
async for data in self.tts(buffer.getvalue(), options, voice_engine, protocol, streaming):
yield data
buffer = io.StringIO()
# If text_stream closes, send all remaining text, regardless of sentence structure.
if buffer.tell() > 0:
async for data in self.tts(buffer.getvalue(), options, voice_engine):
async for data in self.tts(buffer.getvalue(), options, voice_engine, protocol, streaming):
yield data

def tts(
self,
text: Union[str, list[str]],
options: TTSOptions,
voice_engine: Optional[str] = None,
protocol: Optional[str] = None,
streaming: bool = True
) -> AsyncIterable[bytes]:
metrics = self._telemetry.start("tts-request")
try:
voice_engine, protocol = get_voice_engine_and_protocol(voice_engine)
voice_engine, protocol = get_voice_engine_and_protocol(voice_engine, protocol)

if protocol == "http":
return self._tts_http(text, options, voice_engine, metrics, streaming)
elif protocol == "ws":
if streaming:
return self._tts_ws(text, options, voice_engine, metrics)
else:
if not streaming:
raise ValueError("Non-streaming is not supported for WebSocket API")
return self._tts_ws(text, options, voice_engine, metrics)
elif protocol == "grpc":
if streaming:
return self._tts_grpc(text, options, voice_engine, metrics)
else:
if not streaming:
raise ValueError("Non-streaming is not supported for gRPC API")
return self._tts_grpc(text, options, voice_engine, metrics)
else:
raise ValueError(f"Unknown protocol {protocol}")
except Exception as e:
@@ -489,7 +490,8 @@ async def _tts_ws(
def get_stream_pair(
self,
options: TTSOptions,
voice_engine: Optional[str] = None
voice_engine: Optional[str] = None,
protocol: Optional[str] = None
) -> tuple['_InputStream', '_OutputStream']:
"""Get a linked pair of (input, output) streams.
@@ -498,7 +500,7 @@ def get_stream_pair(
"""
shared_q = asyncio.Queue()
return (
_InputStream(self, options, shared_q, voice_engine),
_InputStream(self, options, shared_q, voice_engine, protocol),
_OutputStream(shared_q)
)

@@ -587,11 +589,12 @@ def __init__(
options: TTSOptions,
q: asyncio.Queue[Optional[bytes]],
voice_engine: Optional[str],
protocol: Optional[str] = None
):
self._input = TextStream()

async def listen():
async for output in client.stream_tts_input(self._input, options, voice_engine):
async for output in client.stream_tts_input(self._input, options, voice_engine, protocol):
await q.put(output)
await q.put(None)
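
The linked-stream pattern in the async hunks above (a producer that forwards `protocol` to the synthesis call, pushes chunks onto an `asyncio.Queue`, and ends with a `None` sentinel) can be sketched in isolation. This is a rough illustration under stated assumptions: `fake_tts`, `producer`, `consume`, and `run_pair` are hypothetical stand-ins, and no real Play API call is made.

```python
import asyncio
from typing import AsyncIterator, List, Optional


async def fake_tts(text: str, protocol: Optional[str]) -> AsyncIterator[bytes]:
    # Hypothetical stand-in for a synthesis call: one tagged chunk per word.
    for word in text.split():
        yield f"{protocol or 'default'}:{word}".encode()


async def producer(texts: List[str], q: asyncio.Queue, protocol: Optional[str]) -> None:
    # Forward the protocol to every synthesis call, as the diff does.
    for text in texts:
        async for chunk in fake_tts(text, protocol):
            await q.put(chunk)
    await q.put(None)  # sentinel: tells the consumer that no more audio is coming


async def consume(q: asyncio.Queue) -> List[bytes]:
    chunks = []
    while (chunk := await q.get()) is not None:
        chunks.append(chunk)
    return chunks


async def run_pair(texts: List[str], protocol: Optional[str] = None) -> List[bytes]:
    q: asyncio.Queue = asyncio.Queue()
    task = asyncio.create_task(producer(texts, q, protocol))
    chunks = await consume(q)
    await task
    return chunks
```

Calling `asyncio.run(run_pair(["hi there"], "ws"))` drives both sides to completion. The sentinel is what lets the consumer stop without knowing how many chunks to expect.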

44 changes: 23 additions & 21 deletions pyht/client.py
@@ -64,7 +64,7 @@ class HTTPFormat(Enum):
FORMAT_PCM = "pcm"


# PlayDialog-* and PlayDialogMultilingual-* only
# PlayDialog and PlayDialogMultilingual only
class CandidateRankingMethod(Enum):
# non-streaming only
DescriptionASRWithMeanProbRank = "description_asr_with_mean_prob"
@@ -185,21 +185,21 @@ class TTSOptions:
temperature: Optional[float] = None
top_p: Optional[float] = None

# only applies to Play3.0-* and PlayHT2.0-turbo
# only apply to Play3.0 and PlayHT2.0-turbo
text_guidance: Optional[float] = None
voice_guidance: Optional[float] = None
repetition_penalty: Optional[float] = None

# only applies to Play3.0-*
# only applies to Play3.0
style_guidance: Optional[float] = None

# only applies to PlayHT2.0-*
# only applies to PlayHT2.0
disable_stabilization: Optional[bool] = None

# only applies to Play3.0-* and PlayDialogMultilingual-*
# only applies to Play3.0 and PlayDialogMultilingual
language: Optional[Language] = None

# only applies to PlayDialog-* and PlayDialogMultilingual-*
# only apply to PlayDialog and PlayDialogMultilingual
# leave the _2 params None if generating single-speaker audio
voice_2: Optional[str] = None
turn_prefix: Optional[str] = None
@@ -293,7 +293,7 @@ def http_prepare_dict(text: List[str], options: TTSOptions, voice_engine: str) -
"language": options.language.value if options.language is not None else None,
"version": version,

# PlayDialog-* and PlayDialogMultilingual-*
# PlayDialog and PlayDialogMultilingual
# leave the _2 params None if generating single-speaker audio
"voice_2": options.voice_2,
"turn_prefix": options.turn_prefix,
@@ -506,7 +506,9 @@ def stream_tts_input(
self,
text_stream: Union[Generator[str, None, None], Iterable[str]],
options: TTSOptions,
voice_engine: Optional[str] = None
voice_engine: Optional[str] = None,
protocol: Optional[str] = None,
streaming: bool = True
) -> Iterable[bytes]:
"""Stream input to Play.ht via the text_stream object."""
buffer = io.StringIO()
@@ -516,35 +518,34 @@ def stream_tts_input(
buffer.write(" ") # normalize word spacing.
if SENTENCE_END_REGEX.match(t) is None:
continue
yield from self.tts(buffer.getvalue(), options, voice_engine)
yield from self.tts(buffer.getvalue(), options, voice_engine, protocol, streaming)
buffer = io.StringIO()
# If text_stream closes, send all remaining text, regardless of sentence structure.
if buffer.tell() > 0:
yield from self.tts(buffer.getvalue(), options, voice_engine)
yield from self.tts(buffer.getvalue(), options, voice_engine, protocol, streaming)

def tts(
self,
text: Union[str, List[str]],
options: TTSOptions,
voice_engine: Optional[str] = None,
protocol: Optional[str] = None,
streaming: bool = True
) -> Iterable[bytes]:
metrics = self._telemetry.start("tts-request")
try:
voice_engine, protocol = get_voice_engine_and_protocol(voice_engine)
voice_engine, protocol = get_voice_engine_and_protocol(voice_engine, protocol)

if protocol == "http":
return self._tts_http(text, options, voice_engine, metrics, streaming)
elif protocol == "ws":
if streaming:
return self._tts_ws(text, options, voice_engine, metrics)
else:
if not streaming:
raise ValueError("Non-streaming is not supported for WebSocket API")
return self._tts_ws(text, options, voice_engine, metrics)
elif protocol == "grpc":
if streaming:
return self._tts_grpc(text, options, voice_engine, metrics)
else:
if not streaming:
raise ValueError("Non-streaming is not supported for gRPC API")
return self._tts_grpc(text, options, voice_engine, metrics)
else:
raise ValueError(f"Unknown protocol {protocol}")
except Exception as e:
@@ -757,15 +758,16 @@ def _tts_ws(
def get_stream_pair(
self,
options: TTSOptions,
voice_engine: Optional[str] = None
voice_engine: Optional[str] = None,
protocol: Optional[str] = None
) -> Tuple['_InputStream', '_OutputStream']:
"""Get a linked pair of (input, output) streams.
These stream objects are thread-aware and safe to use in separate threads.
"""
shared_q = queue.Queue()
return (
_InputStream(self, options, shared_q, voice_engine),
_InputStream(self, options, shared_q, voice_engine, protocol),
_OutputStream(shared_q)
)

@@ -818,11 +820,11 @@ class _InputStream:
input_stream.done()
"""
def __init__(self, client: Client, options: TTSOptions, q: queue.Queue[Optional[bytes]],
voice_engine: Optional[str]):
voice_engine: Optional[str], protocol: Optional[str] = None):
self._input = TextStream()

def listen():
for output in client.stream_tts_input(self._input, options, voice_engine):
for output in client.stream_tts_input(self._input, options, voice_engine, protocol):
q.put(output)
q.put(None)
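
The sentence-buffering loop that `stream_tts_input` uses in the hunks above, now passing `protocol` and `streaming` through to each call, can be sketched as a standalone generator. The regular expression, the `synthesize` callback, and `fake_synthesize` are assumptions for illustration. The library's `SENTENCE_END_REGEX` is not shown in this diff and may differ.

```python
import io
import re
from typing import Callable, Iterable, Iterator, Optional

# Assumed pattern: a fragment ends a sentence if it finishes with '.', '!' or '?'.
SENTENCE_END = re.compile(r".*[.!?]\s*$", re.DOTALL)

Synthesize = Callable[[str, Optional[str], bool], Iterable[bytes]]


def stream_sentences(
    text_stream: Iterable[str],
    synthesize: Synthesize,
    protocol: Optional[str] = None,
    streaming: bool = True,
) -> Iterator[bytes]:
    buffer = io.StringIO()
    for t in text_stream:
        buffer.write(t)
        buffer.write(" ")  # normalize word spacing
        if SENTENCE_END.match(t) is None:
            continue  # keep buffering until a sentence boundary arrives
        yield from synthesize(buffer.getvalue(), protocol, streaming)
        buffer = io.StringIO()
    # When the input closes, flush whatever text is left.
    if buffer.tell() > 0:
        yield from synthesize(buffer.getvalue(), protocol, streaming)


def fake_synthesize(text: str, protocol: Optional[str], streaming: bool) -> Iterator[bytes]:
    # Hypothetical stand-in for a synthesis call: echoes the text with its protocol.
    yield f"{protocol}|{text.strip()}".encode()
```

Because the two extra arguments are passed on every call, a caller who picks `ws` or `grpc` once gets that transport for every buffered sentence, including the final flush.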
