From cabfa0a61ace45af0923f8f87acd57a958326f14 Mon Sep 17 00:00:00 2001 From: shibe2 Date: Mon, 20 Nov 2023 14:57:26 +0400 Subject: [PATCH] server: Fix handling of characters that span multiple tokens when streaming --- examples/server/server.cpp | 39 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index d0cd8e1cdb211..39d1e83d18575 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -376,7 +376,6 @@ struct llama_client_slot int32_t num_prompt_tokens = 0; int32_t num_prompt_tokens_processed = 0; - int32_t multibyte_pending = 0; json prompt; std::string generated_text; @@ -425,7 +424,6 @@ struct llama_client_slot stopped_word = false; stopped_limit = false; stopping_word = ""; - multibyte_pending = 0; n_past = 0; sent_count = 0; sent_token_probs_index = 0; @@ -992,35 +990,36 @@ struct llama_server_context slot.generated_text += token_str; slot.has_next_token = true; - if (slot.multibyte_pending > 0) + // check if there is incomplete UTF-8 character at the end + bool incomplete = false; + for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) { - slot.multibyte_pending -= token_str.size(); - } - else if (token_str.size() == 1) - { - const char c = token_str[0]; - // 2-byte characters: 110xxxxx 10xxxxxx + unsigned char c = slot.generated_text[slot.generated_text.size() - i]; + if ((c & 0xC0) == 0x80) + { + // continuation byte: 10xxxxxx + continue; + } if ((c & 0xE0) == 0xC0) { - slot.multibyte_pending = 1; - // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx + // 2-byte character: 110xxxxx ... + incomplete = i < 2; } else if ((c & 0xF0) == 0xE0) { - slot.multibyte_pending = 2; - // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + // 3-byte character: 1110xxxx ... + incomplete = i < 3; } else if ((c & 0xF8) == 0xF0) { - slot.multibyte_pending = 3; - } - else - { - slot.multibyte_pending = 0; + // 4-byte character: 11110xxx ... + incomplete = i < 4; } + // else 1-byte character or invalid byte + break; } - if (slot.multibyte_pending == 0) + if (!incomplete) { size_t pos = std::min(slot.sent_count, slot.generated_text.size()); const std::string str_test = slot.generated_text.substr(pos); @@ -1055,7 +1054,7 @@ struct llama_server_context } } - if (slot.multibyte_pending > 0 && !slot.has_next_token) + if (incomplete) { slot.has_next_token = true; }