Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added rudimentary support for outetts v0.3 500m and 1b models #11287

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 31 additions & 16 deletions examples/tts/tts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ static std::string replace_numbers_with_words(const std::string & input_text) {
}

// Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39
static std::string process_text(const std::string & text) {
static std::string process_text(const std::string & text, bool is_version_0_3) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw to check if the version is 0.3, you can use:

bool is_version_0_3 = common_get_builtin_chat_template(model) == "outetts-0.3"

@edwko I planned to add this as a dedicated GGUF meta key, but turns out I still not have the time to implement this. I'll try to do this in next week! And btw congrats for the release of v0.3 😄


// For now I skipped text romanization as I am unsure how to handle
// uroman and MeCab implementations in C++
Expand Down Expand Up @@ -401,7 +401,7 @@ static std::string process_text(const std::string & text) {
if (c == ' ') {
prompt_clean += "<|text_sep|>";
*/
processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), "<|text_sep|>");
processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), is_version_0_3?"<|space|>":"<|text_sep|>");

return processed_text;
}
Expand All @@ -425,8 +425,7 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
}

static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) {
const std::string& delimiter = "<|text_sep|>";
static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str, const std::string& delimiter) {

std::vector<llama_token> result;
size_t start = 0;
Expand Down Expand Up @@ -523,6 +522,11 @@ int main(int argc, char ** argv) {
std::vector<llama_token> codes;
std::vector<llama_token> guide_tokens;

//determine OuteTTS version and vocab code offset. v0.2 does not have <|space|>, but v0.3 does
const bool is_version_0_3 = (common_get_builtin_chat_template(model_ttc) == "outetts-0.3");
//determine the offset of the first audio code token
const int cts_offset = common_tokenize(vocab,"<|0|>",false,true)[0];

// process prompt and generate voice codes
{
LOG_INF("%s: constructing prompt ..\n", __func__);
Expand All @@ -531,13 +535,17 @@ int main(int argc, char ** argv) {

prompt_init(prompt_inp, vocab);

prompt_add(prompt_inp, vocab, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true);
if (is_version_0_3) {
prompt_add(prompt_inp, vocab, "<|text_start|>the<|space|>overall<|space|>package<|space|>from<|space|>just<|space|>two<|space|>people<|space|>is<|space|>pretty<|space|>remarkable<|space|>sure<|space|>i<|space|>have<|space|>some<|space|>critiques<|space|>about<|space|>some<|space|>of<|space|>the<|space|>gameplay<|space|>aspects<|space|>but<|space|>its<|space|>still<|space|>really<|space|>enjoyable<|space|>and<|space|>it<|space|>looks<|space|>lovely<|space|>", false, true);
} else {
prompt_add(prompt_inp, vocab, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true);
}

// convert the input text into the necessary format expected by OuteTTS
{
std::string prompt_clean = process_text(params.prompt);
std::string prompt_clean = process_text(params.prompt, is_version_0_3);
if (params.vocoder.use_guide_tokens) {
guide_tokens = prepare_guide_tokens(vocab, prompt_clean);
guide_tokens = prepare_guide_tokens(vocab, prompt_clean, is_version_0_3?"<|space|>":"<|text_sep|>");
}

LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
Expand All @@ -549,8 +557,8 @@ int main(int argc, char ** argv) {

// disabled to save time on tokenizing each time
// TODO: load voices from the json files
#if 0
const std::string voice_data = R"(<|audio_start|>
#if 1
std::string voice_data = R"(<|audio_start|>
the<|t_0.08|><|code_start|><|257|><|740|><|636|><|913|><|788|><|1703|><|code_end|>
overall<|t_0.36|><|code_start|><|127|><|201|><|191|><|774|><|700|><|532|><|1056|><|557|><|798|><|298|><|1741|><|747|><|1662|><|1617|><|1702|><|1527|><|368|><|1588|><|1049|><|1008|><|1625|><|747|><|1576|><|728|><|1019|><|1696|><|1765|><|code_end|>
package<|t_0.56|><|code_start|><|935|><|584|><|1319|><|627|><|1016|><|1491|><|1344|><|1117|><|1526|><|1040|><|239|><|1435|><|951|><|498|><|723|><|1180|><|535|><|789|><|1649|><|1637|><|78|><|465|><|1668|><|901|><|595|><|1675|><|117|><|1009|><|1667|><|320|><|840|><|79|><|507|><|1762|><|1508|><|1228|><|1768|><|802|><|1450|><|1457|><|232|><|639|><|code_end|>
Expand Down Expand Up @@ -582,12 +590,19 @@ it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><
looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|>
lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)";

auto tmp = common_tokenize(vocab, voice_data, false, true);
printf("\n\n");
for (int i = 0; i < tmp.size(); ++i) {
printf("%d, ", tmp[i]);
if (is_version_0_3)
{
voice_data = std::regex_replace(voice_data, std::regex(R"(<\|code_start\|>)"), "");
voice_data = std::regex_replace(voice_data, std::regex(R"(<\|code_end\|>)"), "<|space|>");
}
printf("\n\n");

prompt_add(prompt_inp, vocab, voice_data, false, true);

// printf("\n\n");
// for (int i = 0; i < tmp.size(); ++i) {
// printf("%d, ", tmp[i]);
// }
// printf("\n\n");
#else
prompt_add(prompt_inp, llama_tokens {
151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585,
Expand Down Expand Up @@ -882,7 +897,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
}

// remove all non-audio tokens (i.e. < 151672 || > 155772)
codes.erase(std::remove_if(codes.begin(), codes.end(), [](llama_token t) { return t < 151672 || t > 155772; }), codes.end());
codes.erase(std::remove_if(codes.begin(), codes.end(), [cts_offset](llama_token t) { return t < cts_offset || t > (cts_offset+4100); }), codes.end());

{
const std::string inp_txt = common_detokenize(ctx_ttc, codes, true);
Expand All @@ -891,7 +906,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
}

for (auto & token : codes) {
token -= 151672;
token -= cts_offset;
}

const auto t_voc_start = ggml_time_us();
Expand Down
Loading