From db2ffd519dd5b16417003bcac5f9282f51c43b7e Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 30 Jun 2024 14:34:55 -0400 Subject: [PATCH 01/11] llama : fix mpt and olmo pre-tokenizer --- src/llama.cpp | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 2a4d73856fcd9..5082daaaebf2e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5170,6 +5170,28 @@ static void llm_load_vocab( vocab.token_to_id[word] = i; vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size()); + // TODO: properly handle pre-normalized added_tokens and remove this + // handle space tokens with dual tokens, + // like the pre-normalized added_tokens + // of neox-style tokenizers (mpt, olmo, stablelm, etc) + if (word.find(' ') != std::string::npos) { + // same as in the internal `unicode_byte_encoding_process` + // TODO: extract and expose this in some unicode_* function + std::string text_utf; + auto utf_word = unicode_cpts_from_utf8(word); + for (size_t i = 0; i < utf_word.size(); ++i) { + text_utf += unicode_cpt_to_utf8(utf_word[i]); + } + + std::string encoded_token; + for (char & c : text_utf) { + encoded_token += unicode_byte_to_utf8(c); + } + + // override token id + vocab.token_to_id[encoded_token] = i; + } + auto & token_data = vocab.id_to_token[i]; token_data.text = std::move(word); token_data.score = scores ? scores[i] : 0.0f; @@ -13890,13 +13912,9 @@ struct llm_tokenizer_bpe { }; break; case LLAMA_VOCAB_PRE_TYPE_MPT: - // TODO: MPT pre-tokenization regexes are unknown - // the following are close, but not exact. run the following: - // ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf - GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed"); + case LLAMA_VOCAB_PRE_TYPE_OLMO: regex_exprs = { - "\\s?\\p{L}+", - "\\s?\\p{P}+", + "[ ]{2,24}", // the spaces from the added_tokens are split separately "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", }; break; @@ -13909,7 +13927,6 @@ struct llm_tokenizer_bpe { }; break; case LLAMA_VOCAB_PRE_TYPE_GPT2: - case LLAMA_VOCAB_PRE_TYPE_OLMO: regex_exprs = { "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", }; @@ -13985,6 +14002,10 @@ struct llm_tokenizer_bpe { void tokenize(const std::string & text, std::vector & output) { int final_prev_index = -1; + // FIXME: pre-tokenize added_tokens (user-defined tokens) before other pre-tokenization + // ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726 + // (useful for neox-style tokenizers) + const auto word_collection = unicode_regex_split(text, regex_exprs); symbols_final.clear(); From d5d30b20c3382bf88af7c3532fd78e1b9b8a49e4 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 7 Jul 2024 15:32:42 -0400 Subject: [PATCH 02/11] llama : pre-tokenize non-special user-defined tokens first --- src/llama.cpp | 54 ++++++++++++++------------------------ tests/test-tokenizer-0.cpp | 4 +-- 2 files changed, 21 insertions(+), 37 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 2e712a9d62eed..3dfbf792b7fd5 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5495,28 +5495,6 @@ static void llm_load_vocab( vocab.token_to_id[word] = i; vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size()); - // TODO: properly handle pre-normalized added_tokens and remove this - // handle space tokens with dual tokens, - // like the pre-normalized 
added_tokens - // of neox-style tokenizers (mpt, olmo, stablelm, etc) - if (word.find(' ') != std::string::npos) { - // same as in the internal `unicode_byte_encoding_process` - // TODO: extract and expose this in some unicode_* function - std::string text_utf; - auto utf_word = unicode_cpts_from_utf8(word); - for (size_t i = 0; i < utf_word.size(); ++i) { - text_utf += unicode_cpt_to_utf8(utf_word[i]); - } - - std::string encoded_token; - for (char & c : text_utf) { - encoded_token += unicode_byte_to_utf8(c); - } - - // override token id - vocab.token_to_id[encoded_token] = i; - } - auto & token_data = vocab.id_to_token[i]; token_data.text = std::move(word); token_data.score = scores ? scores[i] : 0.0f; @@ -5534,6 +5512,13 @@ static void llm_load_vocab( default: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break; } } + + if ((token_data.attr & LLAMA_TOKEN_ATTR_USER_DEFINED) && token_data.text.find('<') && token_data.text.rfind('>')) { + // Some models mark some added tokens which ought to be control tokens as not special. + // (e.g. command-r, command-r-plus, deepseek-coder) + // TODO: should this be fixed in the convert script instead? + token_data.attr = LLAMA_TOKEN_ATTR_CONTROL; + } } GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size()); @@ -15426,13 +15411,6 @@ struct llm_tokenizer_bpe { "[0-9][0-9][0-9]", }; break; - case LLAMA_VOCAB_PRE_TYPE_MPT: - case LLAMA_VOCAB_PRE_TYPE_OLMO: - regex_exprs = { - "[ ]{2,24}", // the spaces from the added_tokens are split separately - "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", - }; - break; case LLAMA_VOCAB_PRE_TYPE_STARCODER: case LLAMA_VOCAB_PRE_TYPE_REFACT: case LLAMA_VOCAB_PRE_TYPE_COMMAND_R: @@ -15442,6 +15420,8 @@ struct llm_tokenizer_bpe { }; break; case LLAMA_VOCAB_PRE_TYPE_GPT2: + case LLAMA_VOCAB_PRE_TYPE_MPT: + case LLAMA_VOCAB_PRE_TYPE_OLMO: case LLAMA_VOCAB_PRE_TYPE_JAIS: regex_exprs = { "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", @@ -15523,10 +15503,6 @@ struct llm_tokenizer_bpe { void tokenize(const std::string & text, std::vector & output) { int final_prev_index = -1; - // FIXME: pre-tokenize added_tokens (user-defined tokens) before other pre-tokenization - // ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726 - // (useful for neox-style tokenizers) - const auto word_collection = unicode_regex_split(text, regex_exprs); symbols_final.clear(); @@ -16192,12 +16168,20 @@ struct fragment_buffer_variant { // #define PRETOKENIZERDEBUG -static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list & buffer) { +static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list & buffer, bool parse_special) { // for each special token for (const llama_vocab::id special_id : vocab.cache_special_tokens) { const auto & data = vocab.id_to_token[special_id]; const auto & special_token = data.text; + if (!parse_special && (data.attr & LLAMA_TOKEN_ATTR_CONTROL)) { + // Only ignore control tokens when parse_special == false + continue; + // User-defined tokens are still pre-tokenized before everything else + // ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726 + // This is mostly relevant for neox-style tokenizers (mpt, olmo, stablelm, etc.) 
+ } + // for each text fragment std::forward_list::iterator it = buffer.begin(); while (it != buffer.end()) { @@ -16310,7 +16294,7 @@ static std::vector llama_tokenize_internal(const llama_vocab & if (!raw_text.empty()) { fragment_buffer.emplace_front(raw_text, 0, raw_text.length()); - if (parse_special) tokenizer_st_partition(vocab, fragment_buffer); + tokenizer_st_partition(vocab, fragment_buffer, parse_special); } switch (vocab.type) { diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp index 1f04b6f34ad7e..0c2d7781bedf9 100644 --- a/tests/test-tokenizer-0.cpp +++ b/tests/test-tokenizer-0.cpp @@ -195,7 +195,7 @@ int main(int argc, char **argv) { const bool add_special = false; for (const auto & test_kv : k_tests) { - const std::vector res = llama_tokenize(ctx, test_kv.first, add_special, true); + const std::vector res = llama_tokenize(ctx, test_kv.first, add_special); printf("\n"); printf("src: '%s'\n", test_kv.first.c_str()); @@ -253,7 +253,7 @@ int main(int argc, char **argv) { { const auto t_start = ggml_time_us(); - res = llama_tokenize(ctx, text, add_special, true); + res = llama_tokenize(ctx, text, add_special); const auto t_end = ggml_time_us(); From 56df1fcdcb6e9abf74e11ea05741fa65dbc020be Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 7 Jul 2024 16:13:35 -0400 Subject: [PATCH 03/11] llama : fix detection of control-like user-defined tokens --- src/llama.cpp | 3 ++- tests/test-tokenizer-0.cpp | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 3dfbf792b7fd5..1794ec2bd8b82 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5513,7 +5513,8 @@ static void llm_load_vocab( } } - if ((token_data.attr & LLAMA_TOKEN_ATTR_USER_DEFINED) && token_data.text.find('<') && token_data.text.rfind('>')) { + if ((token_data.attr & LLAMA_TOKEN_ATTR_USER_DEFINED) && !token_data.text.empty() && + token_data.text.front() == '<' && token_data.text.back() == '>') { // Some models mark some added tokens which ought to be control tokens as not special. // (e.g. command-r, command-r-plus, deepseek-coder) // TODO: should this be fixed in the convert script instead? diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp index 0c2d7781bedf9..d3d21331bfd3d 100644 --- a/tests/test-tokenizer-0.cpp +++ b/tests/test-tokenizer-0.cpp @@ -195,7 +195,7 @@ int main(int argc, char **argv) { const bool add_special = false; for (const auto & test_kv : k_tests) { - const std::vector res = llama_tokenize(ctx, test_kv.first, add_special); + const std::vector res = llama_tokenize(ctx, test_kv.first, add_special, false); printf("\n"); printf("src: '%s'\n", test_kv.first.c_str()); @@ -253,7 +253,7 @@ int main(int argc, char **argv) { { const auto t_start = ggml_time_us(); - res = llama_tokenize(ctx, text, add_special); + res = llama_tokenize(ctx, text, add_special, false); const auto t_end = ggml_time_us(); From 6e351e04252e5956432078f673d69b5f19de318d Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 7 Jul 2024 16:59:00 -0400 Subject: [PATCH 04/11] convert_hf : identify which user-defined tokens are control tokens Only used in _set_vocab_gpt2() for now. 
--- convert_hf_to_gguf.py | 17 +++++++++++++++-- src/llama.cpp | 8 -------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 6cea73f08b743..30f87a9fe0f0c 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -373,6 +373,18 @@ def from_model_architecture(cls, arch: str) -> type[Model]: except KeyError: raise NotImplementedError(f'Architecture {arch!r} not supported!') from None + def does_token_look_special(self, token: str) -> bool: + # Some models mark some added tokens which ought to be control tokens as not special. + # (e.g. command-r, command-r-plus, deepseek-coder, gemma{,-2}) + is_known_special = token in ( + "", # deepseek-coder + "", "<2mass>", "[@BOS@]", # gemma{,-2} + ) + # TODO: should these be marked as UNUSED instead? + is_known_special = is_known_special or (token.startswith("")) # gemma{,-2} + + return is_known_special or (token.startswith(("<|", "<|")) and token.endswith(("|>", "|>"))) + # used for GPT-2 BPE and WordPiece vocabs def get_vocab_base(self) -> tuple[list[str], list[int], str]: tokens: list[str] = [] @@ -393,8 +405,9 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]: tokens.append(f"[PAD{i}]") toktypes.append(gguf.TokenType.USER_DEFINED) elif reverse_vocab[i] in added_vocab: - tokens.append(reverse_vocab[i]) - if tokenizer.added_tokens_decoder[i].special: + token: str = reverse_vocab[i] + tokens.append(token) + if tokenizer.added_tokens_decoder[i].special or self.does_token_look_special(token): toktypes.append(gguf.TokenType.CONTROL) else: toktypes.append(gguf.TokenType.USER_DEFINED) diff --git a/src/llama.cpp b/src/llama.cpp index 1794ec2bd8b82..11147eb1159b9 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5512,14 +5512,6 @@ static void llm_load_vocab( default: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break; } } - - if ((token_data.attr & LLAMA_TOKEN_ATTR_USER_DEFINED) && !token_data.text.empty() && - token_data.text.front() == '<' && token_data.text.back() == '>') { - // Some models mark some added tokens which ought to be control tokens as not special. - // (e.g. command-r, command-r-plus, deepseek-coder) - // TODO: should this be fixed in the convert script instead? - token_data.attr = LLAMA_TOKEN_ATTR_CONTROL; - } } GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size()); From f9d42c598bba0fd10568cbdc507d1de20a17244a Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 7 Jul 2024 23:28:38 -0400 Subject: [PATCH 05/11] convert_hf : identify more added control tokens for SPM tokenziers This makes Gemma and Gemma-2 tokenize pretty much EVERYTHING correctly, including HTML tags and consecutive spaces, but it unfortunately requires model re-conversion. There seems to be a weird behavior of the HF tokenizer for Gemma, which prefers to use the 16-space token over more lengthy space tokens, while using the SentencePiece tokenizer does not do this. 
(the implementation in llama.cpp has the same behavior as SentencePiece) * llama : fix wrong pre-tokenization of byte tokens --- convert_hf_to_gguf.py | 80 +++++++++++++++++++++------------- src/llama.cpp | 2 +- tests/test-tokenizer-random.py | 10 ++--- 3 files changed, 55 insertions(+), 37 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 30f87a9fe0f0c..da6d2ba9e2652 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -373,17 +373,28 @@ def from_model_architecture(cls, arch: str) -> type[Model]: except KeyError: raise NotImplementedError(f'Architecture {arch!r} not supported!') from None - def does_token_look_special(self, token: str) -> bool: + def does_token_look_special(self, token: str | bytes) -> bool: + if isinstance(token, (bytes, bytearray)): + token_text = token.decode(encoding="utf-8") + elif isinstance(token, memoryview): + token_text = token.tobytes().decode(encoding="utf-8") + else: + token_text = token + # Some models mark some added tokens which ought to be control tokens as not special. # (e.g. command-r, command-r-plus, deepseek-coder, gemma{,-2}) - is_known_special = token in ( + seems_special = token_text in ( "", # deepseek-coder "", "<2mass>", "[@BOS@]", # gemma{,-2} ) - # TODO: should these be marked as UNUSED instead? - is_known_special = is_known_special or (token.startswith("")) # gemma{,-2} - return is_known_special or (token.startswith(("<|", "<|")) and token.endswith(("|>", "|>"))) + seems_special = seems_special or (token_text.startswith("<|") and token_text.endswith("|>")) + seems_special = seems_special or (token_text.startswith("<|") and token_text.endswith("|>")) # deepseek-coder + + # TODO: should these be marked as UNUSED instead? (maybe not) + seems_special = seems_special or (token_text.startswith("")) # gemma{,-2} + + return seems_special # used for GPT-2 BPE and WordPiece vocabs def get_vocab_base(self) -> tuple[list[str], list[int], str]: @@ -403,17 +414,18 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]: for i in range(vocab_size): if i not in reverse_vocab: tokens.append(f"[PAD{i}]") - toktypes.append(gguf.TokenType.USER_DEFINED) - elif reverse_vocab[i] in added_vocab: + toktypes.append(gguf.TokenType.UNUSED) + else: token: str = reverse_vocab[i] - tokens.append(token) - if tokenizer.added_tokens_decoder[i].special or self.does_token_look_special(token): - toktypes.append(gguf.TokenType.CONTROL) + if token in added_vocab: + if tokenizer.added_tokens_decoder[i].special or self.does_token_look_special(token): + toktypes.append(gguf.TokenType.CONTROL) + else: + token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces + toktypes.append(gguf.TokenType.USER_DEFINED) else: - toktypes.append(gguf.TokenType.USER_DEFINED) - else: - tokens.append(reverse_vocab[i]) - toktypes.append(gguf.TokenType.NORMAL) + toktypes.append(gguf.TokenType.NORMAL) + tokens.append(token) return tokens, toktypes, tokpre @@ -572,7 +584,7 @@ def _set_vocab_qwen(self): for i in range(vocab_size): if i not in reverse_vocab: tokens.append(f"[PAD{i}]") - toktypes.append(gguf.TokenType.USER_DEFINED) + toktypes.append(gguf.TokenType.UNUSED) elif reverse_vocab[i] in added_vocab: tokens.append(reverse_vocab[i]) toktypes.append(gguf.TokenType.CONTROL) @@ -657,6 +669,25 @@ def _create_vocab_sentencepiece(self): scores[token_id] = -1000.0 toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED + tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + if 
tokenizer_config_file.is_file(): + with open(tokenizer_config_file, "r", encoding="utf-8") as f: + tokenizer_config_json = json.load(f) + added_tokens_decoder = tokenizer_config_json.get("added_tokens_decoder", {}) + for token_id, token_data in added_tokens_decoder.items(): + token_id = int(token_id) + token: str = token_data["content"] + if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN: + assert tokens[token_id] == token.encode("utf-8") + if token_data.get("special") or self.does_token_look_special(token): + toktypes[token_id] = SentencePieceTokenTypes.CONTROL + else: + token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces + toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED + + scores[token_id] = -1000.0 + tokens[token_id] = token.encode("utf-8") + if vocab_size > len(tokens): pad_count = vocab_size - len(tokens) logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]") @@ -1280,7 +1311,7 @@ def set_vocab(self): if (self.dir_model / "tokenizer.json").is_file(): self._set_vocab_gpt2() else: - # StableLM 2 1.6B uses a vocab in a similar format to Qwen's vocab + # StableLM 2 1.6B used to have a vocab in a similar format to Qwen's vocab self._set_vocab_qwen() def set_gguf_parameters(self): @@ -1592,7 +1623,6 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_freq_base(attn_config["rope_theta"]) self.gguf_writer.add_clamp_kqv(attn_config["clip_qkv"]) - self.gguf_writer.add_file_type(self.ftype) self.gguf_writer.add_expert_count(ffn_config["moe_num_experts"]) self.gguf_writer.add_expert_used_count(ffn_config["moe_top_k"]) @@ -2412,19 +2442,7 @@ class Gemma2Model(Model): model_arch = gguf.MODEL_ARCH.GEMMA2 def set_vocab(self): - tokens, scores, toktypes = self._create_vocab_sentencepiece() - # hack: This is required so that we can properly use start/end-of-turn for chat template - for i in range(108): - # including , , - toktypes[i] = SentencePieceTokenTypes.CONTROL - self.gguf_writer.add_tokenizer_model("llama") - self.gguf_writer.add_tokenizer_pre("default") - self.gguf_writer.add_token_list(tokens) - self.gguf_writer.add_token_scores(scores) - self.gguf_writer.add_token_types(toktypes) - - special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) - special_vocab.add_to_gguf(self.gguf_writer) + self._set_vocab_sentencepiece() self.gguf_writer.add_add_space_prefix(False) @@ -3318,7 +3336,7 @@ def set_vocab(self): for i in range(vocab_size): if i not in reverse_vocab: tokens.append(f"[PAD{i}]") - toktypes.append(gguf.TokenType.USER_DEFINED) + toktypes.append(gguf.TokenType.UNUSED) elif reverse_vocab[i] in added_vocab: tokens.append(reverse_vocab[i]) if tokenizer.added_tokens_decoder[i].special: diff --git a/src/llama.cpp b/src/llama.cpp index 11147eb1159b9..c30d0adfe00b4 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5640,7 +5640,7 @@ static void llm_load_vocab( // build special tokens cache { for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) { - if (!(vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL)) { + if (vocab.id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED)) { vocab.cache_special_tokens.push_back(id); } } diff --git a/tests/test-tokenizer-random.py b/tests/test-tokenizer-random.py index c50a8ca32f657..cdfc2b12c1178 100644 --- a/tests/test-tokenizer-random.py +++ b/tests/test-tokenizer-random.py @@ -20,7 +20,7 @@ from typing_extensions import Buffer import cffi -from transformers import AutoTokenizer +from transformers 
import AutoTokenizer, PreTrainedTokenizer logger = logging.getLogger("test-tokenizer-random") @@ -129,7 +129,7 @@ def decode(self, ids: list[int]) -> str: class TokenizerGroundtruth (Tokenizer): def __init__(self, dir_tokenizer: str): - self.model = AutoTokenizer.from_pretrained(dir_tokenizer) + self.model: PreTrainedTokenizer = AutoTokenizer.from_pretrained(dir_tokenizer) # guess BOS and EOS ids = self.encode("a") assert 1 <= len(ids) <= 3 @@ -143,7 +143,7 @@ def __init__(self, dir_tokenizer: str): self.vocab = list(sorted(self.vocab)) # tokens and lists self.special_tokens = list(self.model.all_special_tokens) - self.added_tokens = list(self.model.added_tokens_encoder) + self.added_tokens = self.model.batch_decode(self.model.added_tokens_encoder.values(), skip_special_tokens=False) self.bos_token = self.model.bos_token self.eos_token = self.model.eos_token @@ -458,8 +458,8 @@ def check_detokenizer(text: str, text1: str, text2: str) -> bool: i = find_first_mismatch(ids1, ids2) ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1] ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1] - logger.error(" Expected: " + str(ids1)) - logger.error(" Result: " + str(ids2)) + logger.error(" Expected: " + str(ids1) + f" {[tokenizer1.decode([id]) for id in ids1]}") + logger.error(" Result: " + str(ids2) + f" {[tokenizer2.decode([id]) for id in ids2]}") encode_errors += 1 logger.error(f" {encode_errors=}") if decode_errors < MAX_ERRORS and not check_detokenizer(text, text1, text2): From 31a1b0eeaa2c690f63772844fdac1ac24ed024c8 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 8 Jul 2024 16:34:39 -0400 Subject: [PATCH 06/11] llama : fix Viking pre-tokenizer regex The order was previously wrong, which caused errors in some tests. --- src/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama.cpp b/src/llama.cpp index c30d0adfe00b4..b652762d29000 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -15440,8 +15440,8 @@ struct llm_tokenizer_bpe { break; case LLAMA_VOCAB_PRE_TYPE_VIKING: regex_exprs = { - "\\p{N}", " ?[^(\\s|.,!?…。,、।۔،)]+", + "\\p{N}", }; break; default: From d6fe269ced93d45783a3b37c3cb20554264e5578 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 8 Jul 2024 18:13:16 -0400 Subject: [PATCH 07/11] llama : fix command-r detokenization --- src/llama.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/llama.cpp b/src/llama.cpp index b652762d29000..3509ff59948e8 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5407,6 +5407,7 @@ static void llm_load_vocab( } else if ( tokenizer_pre == "command-r") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R; + vocab.tokenizer_clean_spaces = false; } else if ( tokenizer_pre == "qwen2") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2; From d4df785868a7a638b20fdf9b9f3e34bc48cdcae3 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 8 Jul 2024 21:09:52 -0400 Subject: [PATCH 08/11] convert_hf : reduce usages of the UNKNOWN token type --- convert_hf_to_gguf.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index da6d2ba9e2652..b2bfb695b92aa 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -634,7 +634,7 @@ def _create_vocab_sentencepiece(self): tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] scores: list[float] = [-10000.0] * vocab_size - toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size + toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * 
vocab_size for token_id in range(tokenizer.vocab_size()): piece = tokenizer.IdToPiece(token_id) @@ -677,7 +677,7 @@ def _create_vocab_sentencepiece(self): for token_id, token_data in added_tokens_decoder.items(): token_id = int(token_id) token: str = token_data["content"] - if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN: + if toktypes[token_id] != SentencePieceTokenTypes.UNUSED: assert tokens[token_id] == token.encode("utf-8") if token_data.get("special") or self.does_token_look_special(token): toktypes[token_id] = SentencePieceTokenTypes.CONTROL @@ -1916,7 +1916,7 @@ def set_vocab(self): tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] scores: list[float] = [-10000.0] * vocab_size - toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size + toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size for token_id in range(tokenizer.vocab_size()): @@ -1961,7 +1961,7 @@ def set_vocab(self): for token_id, foken_data in added_tokens_decoder.items(): token_id = int(token_id) token = foken_data["content"].encode("utf-8") - if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN: + if toktypes[token_id] != SentencePieceTokenTypes.UNUSED: assert tokens[token_id] == token tokens[token_id] = token scores[token_id] = -1000.0 @@ -1977,7 +1977,7 @@ def set_vocab(self): for foken_data in added_tokens: token_id = int(foken_data["id"]) token = foken_data["content"].encode("utf-8") - if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN: + if toktypes[token_id] != SentencePieceTokenTypes.UNUSED: assert tokens[token_id] == token tokens[token_id] = token scores[token_id] = -1000.0 @@ -2766,7 +2766,7 @@ def set_vocab(self): tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] scores: list[float] = [-10000.0] * vocab_size - toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size + toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size for token_id in range(tokenizer.vocab_size()): @@ -3021,7 +3021,7 @@ def set_vocab(self): tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] scores: list[float] = [-10000.0] * vocab_size - toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size + toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size for token_id in range(tokenizer.vocab_size()): piece = tokenizer.IdToPiece(token_id) @@ -3239,15 +3239,14 @@ def set_vocab_chatglm3(self): if len(piece) != 0 and token_id < tokenizer.tokenizer.sp_model.vocab_size(): score = tokenizer.tokenizer.sp_model.get_score(token_id) - if len(piece) == 0: - text = f"[PAD{token_id}]".encode("utf-8") - if token_id >= tokenizer.tokenizer.sp_model.vocab_size(): if piece in special_tokens: - # show special tokens in prompt - toktype = SentencePieceTokenTypes.USER_DEFINED + toktype = SentencePieceTokenTypes.CONTROL + elif len(piece) == 0: + text = f"[PAD{token_id}]".encode("utf-8") + toktype = SentencePieceTokenTypes.UNUSED else: - toktype = SentencePieceTokenTypes.UNKNOWN + toktype = SentencePieceTokenTypes.USER_DEFINED tokens.append(text) scores.append(score) toktypes.append(toktype) From 98edea60bcd67458d70138c13c428414cd8ec63a Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 8 Jul 2024 21:23:19 -0400 Subject: [PATCH 09/11] llama : add UNKNOWN tokens in the special tokens cache --- src/llama.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 3509ff59948e8..a2e8d62fc1c19 100644 --- a/src/llama.cpp +++ 
b/src/llama.cpp @@ -5641,7 +5641,7 @@ static void llm_load_vocab( // build special tokens cache { for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) { - if (vocab.id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED)) { + if (vocab.id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) { vocab.cache_special_tokens.push_back(id); } } @@ -16168,8 +16168,8 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list< const auto & data = vocab.id_to_token[special_id]; const auto & special_token = data.text; - if (!parse_special && (data.attr & LLAMA_TOKEN_ATTR_CONTROL)) { - // Only ignore control tokens when parse_special == false + if (!parse_special && (data.attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_UNKNOWN))) { + // Ignore control and unknown tokens when parse_special == false continue; // User-defined tokens are still pre-tokenized before everything else // ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726 From 1caa20fc7a4bd0eac1cc26e5c7262c3dadeaf952 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 10 Jul 2024 17:33:04 -0400 Subject: [PATCH 10/11] convert_hf : reduce usages of UNKNOWN for InternLM2 This makes the changes from #8321 more consistent with the other changes made here. --- convert_hf_to_gguf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0236166b32921..c15c126eb52df 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2189,7 +2189,7 @@ def set_vocab(self): toktype = SentencePieceTokenTypes.BYTE # take care of ununsed raw token if piece.startswith('[UNUSED'): - toktype = SentencePieceTokenTypes.UNKNOWN + toktype = SentencePieceTokenTypes.UNUSED tokens.append(text) scores.append(score) @@ -2219,7 +2219,7 @@ def set_vocab(self): if token == chat_eos_token: chat_eos_token_id = token_id token = token.encode("utf-8") - if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN: + if toktypes[token_id] != SentencePieceTokenTypes.UNUSED: assert(tokens[token_id] == token) tokens[token_id] = token scores[token_id] = -1000.0 @@ -2238,7 +2238,7 @@ def set_vocab(self): if token == chat_eos_token: chat_eos_token_id = token_id token = token.encode("utf-8") - if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN: + if toktypes[token_id] != SentencePieceTokenTypes.UNUSED: assert(tokens[token_id] == token) tokens[token_id] = token scores[token_id] = -1000.0 From 59ce85318aeace3d100395b86061fd10b3af6f54 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 13 Jul 2024 01:03:32 -0400 Subject: [PATCH 11/11] test-tokenizer-random : reduce potential confilcts with #8379 * test-tokenizer-random : add a failing edge case for falcon --- tests/test-tokenizer-random.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test-tokenizer-random.py b/tests/test-tokenizer-random.py index cdfc2b12c1178..9ebe6c89185a3 100644 --- a/tests/test-tokenizer-random.py +++ b/tests/test-tokenizer-random.py @@ -232,6 +232,7 @@ def generator_custom_text_edge_cases() -> Iterator[str]: 'a\na', # bert fail '"`', # falcon ' \u2e4e', # falcon + '\n\x0b ', # falcon 'a\xa0\xa0\x00b', # jina-v2-es 'one ', # jina-v2-es lstrip=true 'a b', # rstrip phi-3 @@ -458,8 +459,8 @@ def check_detokenizer(text: str, text1: str, text2: str) -> bool: i = find_first_mismatch(ids1, ids2) ids1 = 
list(ids1)[max(0, i - 2) : i + 5 + 1] ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1] - logger.error(" Expected: " + str(ids1) + f" {[tokenizer1.decode([id]) for id in ids1]}") - logger.error(" Result: " + str(ids2) + f" {[tokenizer2.decode([id]) for id in ids2]}") + logger.error(" Expected: " + str(ids1)) + logger.error(" Result: " + str(ids2)) encode_errors += 1 logger.error(f" {encode_errors=}") if decode_errors < MAX_ERRORS and not check_detokenizer(text, text1, text2):
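
For reference, the ordering change at the heart of this series (added/user-defined tokens are partitioned out of the input before the regex-based pre-tokenization, mirroring the huggingface/tokenizers behaviour referenced in the patches above) can be illustrated with a small Python sketch. This is not the llama.cpp implementation; `partition_added_tokens`, `pre_tokenize` and the toy regex are hypothetical names used only to show the effect on neox-style vocabs whose added tokens contain runs of spaces.

```python
# Minimal illustrative sketch, not the llama.cpp code path.
# Assumption: added tokens are matched greedily (longest first) before regex splitting.
import re

def partition_added_tokens(text: str, added_tokens: list[str]) -> list[tuple[str, bool]]:
    """Split text into (fragment, is_added_token) pairs, longest added tokens matched first."""
    if not added_tokens:
        return [(text, False)]
    pattern = "|".join(re.escape(t) for t in sorted(added_tokens, key=len, reverse=True))
    parts: list[tuple[str, bool]] = []
    pos = 0
    for m in re.finditer(pattern, text):
        if m.start() > pos:
            parts.append((text[pos:m.start()], False))  # raw text between added tokens
        parts.append((m.group(0), True))                 # the added token itself
        pos = m.end()
    if pos < len(text):
        parts.append((text[pos:], False))
    return parts

def pre_tokenize(text: str, added_tokens: list[str], regex: str) -> list[str]:
    """Added tokens pass through whole; only the raw fragments are regex-split."""
    words: list[str] = []
    for fragment, is_added in partition_added_tokens(text, added_tokens):
        if is_added:
            words.append(fragment)
        else:
            words.extend(re.findall(regex, fragment))
    return words

# With a 16-space added token (as in neox-style vocabs), the run of spaces is kept
# as a single pre-token instead of being broken up by the GPT-2 style regex:
print(pre_tokenize("a" + " " * 16 + "b", [" " * 16], r" ?\w+|\s+|[^\s\w]+"))
# ['a', <the 16-space token>, 'b']
```

Without the added-token pass, the same GPT-2 style regex would split the space run away from the surrounding text, which is the mismatch against the reference tokenizers that the earlier `"[ ]{2,24}"` workaround (removed in PATCH 02) tried to paper over.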