From db2ffd519dd5b16417003bcac5f9282f51c43b7e Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 30 Jun 2024 14:34:55 -0400 Subject: [PATCH 01/11] llama : fix mpt and olmo pre-tokenizer --- src/llama.cpp | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 2a4d73856fcd93..5082daaaebf2e8 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5170,6 +5170,28 @@ static void llm_load_vocab( vocab.token_to_id[word] = i; vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size()); + // TODO: properly handle pre-normalized added_tokens and remove this + // handle space tokens with dual tokens, + // like the pre-normalized added_tokens + // of neox-style tokenizers (mpt, olmo, stablelm, etc) + if (word.find(' ') != std::string::npos) { + // same as in the internal `unicode_byte_encoding_process` + // TODO: extract and expose this in some unicode_* function + std::string text_utf; + auto utf_word = unicode_cpts_from_utf8(word); + for (size_t i = 0; i < utf_word.size(); ++i) { + text_utf += unicode_cpt_to_utf8(utf_word[i]); + } + + std::string encoded_token; + for (char & c : text_utf) { + encoded_token += unicode_byte_to_utf8(c); + } + + // override token id + vocab.token_to_id[encoded_token] = i; + } + auto & token_data = vocab.id_to_token[i]; token_data.text = std::move(word); token_data.score = scores ? scores[i] : 0.0f; @@ -13890,13 +13912,9 @@ struct llm_tokenizer_bpe { }; break; case LLAMA_VOCAB_PRE_TYPE_MPT: - // TODO: MPT pre-tokenization regexes are unknown - // the following are close, but not exact. run the following: - // ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf - GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed"); + case LLAMA_VOCAB_PRE_TYPE_OLMO: regex_exprs = { - "\\s?\\p{L}+", - "\\s?\\p{P}+", + "[ ]{2,24}", // the spaces from the added_tokens are split separately "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", }; break; @@ -13909,7 +13927,6 @@ struct llm_tokenizer_bpe { }; break; case LLAMA_VOCAB_PRE_TYPE_GPT2: - case LLAMA_VOCAB_PRE_TYPE_OLMO: regex_exprs = { "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", }; @@ -13985,6 +14002,10 @@ struct llm_tokenizer_bpe { void tokenize(const std::string & text, std::vector & output) { int final_prev_index = -1; + // FIXME: pre-tokenize added_tokens (user-defined tokens) before other pre-tokenization + // ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726 + // (useful for neox-style tokenizers) + const auto word_collection = unicode_regex_split(text, regex_exprs); symbols_final.clear(); From d5d30b20c3382bf88af7c3532fd78e1b9b8a49e4 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 7 Jul 2024 15:32:42 -0400 Subject: [PATCH 02/11] llama : pre-tokenize non-special user-defined tokens first --- src/llama.cpp | 54 ++++++++++++++------------------------ tests/test-tokenizer-0.cpp | 4 +-- 2 files changed, 21 insertions(+), 37 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 2e712a9d62eed2..3dfbf792b7fd50 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5495,28 +5495,6 @@ static void llm_load_vocab( vocab.token_to_id[word] = i; vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size()); - // TODO: properly handle pre-normalized added_tokens and remove this - // handle space tokens with dual tokens, - // like the pre-normalized added_tokens - // of neox-style tokenizers (mpt, olmo, stablelm, etc) - if (word.find(' ') != std::string::npos) { - // same as in the internal `unicode_byte_encoding_process` - // TODO: extract and expose this in some unicode_* function - std::string text_utf; - auto utf_word = unicode_cpts_from_utf8(word); - for (size_t i = 0; i < utf_word.size(); ++i) { - text_utf += unicode_cpt_to_utf8(utf_word[i]); - } - - std::string encoded_token; - for (char & c : text_utf) { - encoded_token += unicode_byte_to_utf8(c); - } - - // override token id - vocab.token_to_id[encoded_token] = i; - } - auto & token_data = vocab.id_to_token[i]; token_data.text = std::move(word); token_data.score = scores ? scores[i] : 0.0f; @@ -5534,6 +5512,13 @@ static void llm_load_vocab( default: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break; } } + + if ((token_data.attr & LLAMA_TOKEN_ATTR_USER_DEFINED) && token_data.text.find('<') && token_data.text.rfind('>')) { + // Some models mark some added tokens which ought to be control tokens as not special. + // (e.g. command-r, command-r-plus, deepseek-coder) + // TODO: should this be fixed in the convert script instead? + token_data.attr = LLAMA_TOKEN_ATTR_CONTROL; + } } GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size()); @@ -15426,13 +15411,6 @@ struct llm_tokenizer_bpe { "[0-9][0-9][0-9]", }; break; - case LLAMA_VOCAB_PRE_TYPE_MPT: - case LLAMA_VOCAB_PRE_TYPE_OLMO: - regex_exprs = { - "[ ]{2,24}", // the spaces from the added_tokens are split separately - "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", - }; - break; case LLAMA_VOCAB_PRE_TYPE_STARCODER: case LLAMA_VOCAB_PRE_TYPE_REFACT: case LLAMA_VOCAB_PRE_TYPE_COMMAND_R: @@ -15442,6 +15420,8 @@ struct llm_tokenizer_bpe { }; break; case LLAMA_VOCAB_PRE_TYPE_GPT2: + case LLAMA_VOCAB_PRE_TYPE_MPT: + case LLAMA_VOCAB_PRE_TYPE_OLMO: case LLAMA_VOCAB_PRE_TYPE_JAIS: regex_exprs = { "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", @@ -15523,10 +15503,6 @@ struct llm_tokenizer_bpe { void tokenize(const std::string & text, std::vector & output) { int final_prev_index = -1; - // FIXME: pre-tokenize added_tokens (user-defined tokens) before other pre-tokenization - // ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726 - // (useful for neox-style tokenizers) - const auto word_collection = unicode_regex_split(text, regex_exprs); symbols_final.clear(); @@ -16192,12 +16168,20 @@ struct fragment_buffer_variant { // #define PRETOKENIZERDEBUG -static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list & buffer) { +static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list & buffer, bool parse_special) { // for each special token for (const llama_vocab::id special_id : vocab.cache_special_tokens) { const auto & data = vocab.id_to_token[special_id]; const auto & special_token = data.text; + if (!parse_special && (data.attr & LLAMA_TOKEN_ATTR_CONTROL)) { + // Only ignore control tokens when parse_special == false + continue; + // User-defined tokens are still pre-tokenized before everything else + // ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726 + // This is mostly relevant for neox-style tokenizers (mpt, olmo, stablelm, etc.) + } + // for each text fragment std::forward_list::iterator it = buffer.begin(); while (it != buffer.end()) { @@ -16310,7 +16294,7 @@ static std::vector llama_tokenize_internal(const llama_vocab & if (!raw_text.empty()) { fragment_buffer.emplace_front(raw_text, 0, raw_text.length()); - if (parse_special) tokenizer_st_partition(vocab, fragment_buffer); + tokenizer_st_partition(vocab, fragment_buffer, parse_special); } switch (vocab.type) { diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp index 1f04b6f34ad7e4..0c2d7781bedf95 100644 --- a/tests/test-tokenizer-0.cpp +++ b/tests/test-tokenizer-0.cpp @@ -195,7 +195,7 @@ int main(int argc, char **argv) { const bool add_special = false; for (const auto & test_kv : k_tests) { - const std::vector res = llama_tokenize(ctx, test_kv.first, add_special, true); + const std::vector res = llama_tokenize(ctx, test_kv.first, add_special); printf("\n"); printf("src: '%s'\n", test_kv.first.c_str()); @@ -253,7 +253,7 @@ int main(int argc, char **argv) { { const auto t_start = ggml_time_us(); - res = llama_tokenize(ctx, text, add_special, true); + res = llama_tokenize(ctx, text, add_special); const auto t_end = ggml_time_us(); From 56df1fcdcb6e9abf74e11ea05741fa65dbc020be Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 7 Jul 2024 16:13:35 -0400 Subject: [PATCH 03/11] llama : fix detection of control-like user-defined tokens --- src/llama.cpp | 3 ++- tests/test-tokenizer-0.cpp | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 3dfbf792b7fd50..1794ec2bd8b826 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5513,7 +5513,8 @@ static void llm_load_vocab( } } - if ((token_data.attr & LLAMA_TOKEN_ATTR_USER_DEFINED) && token_data.text.find('<') && token_data.text.rfind('>')) { + if ((token_data.attr & LLAMA_TOKEN_ATTR_USER_DEFINED) && !token_data.text.empty() && + token_data.text.front() == '<' && token_data.text.back() == '>') { // Some models mark some added tokens which ought to be control tokens as not special. // (e.g. command-r, command-r-plus, deepseek-coder) // TODO: should this be fixed in the convert script instead? diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp index 0c2d7781bedf95..d3d21331bfd3d1 100644 --- a/tests/test-tokenizer-0.cpp +++ b/tests/test-tokenizer-0.cpp @@ -195,7 +195,7 @@ int main(int argc, char **argv) { const bool add_special = false; for (const auto & test_kv : k_tests) { - const std::vector res = llama_tokenize(ctx, test_kv.first, add_special); + const std::vector res = llama_tokenize(ctx, test_kv.first, add_special, false); printf("\n"); printf("src: '%s'\n", test_kv.first.c_str()); @@ -253,7 +253,7 @@ int main(int argc, char **argv) { { const auto t_start = ggml_time_us(); - res = llama_tokenize(ctx, text, add_special); + res = llama_tokenize(ctx, text, add_special, false); const auto t_end = ggml_time_us(); From 6e351e04252e5956432078f673d69b5f19de318d Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 7 Jul 2024 16:59:00 -0400 Subject: [PATCH 04/11] convert_hf : identify which user-defined tokens are control tokens Only used in _set_vocab_gpt2() for now. --- convert_hf_to_gguf.py | 17 +++++++++++++++-- src/llama.cpp | 8 -------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 6cea73f08b7434..30f87a9fe0f0c9 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -373,6 +373,18 @@ def from_model_architecture(cls, arch: str) -> type[Model]: except KeyError: raise NotImplementedError(f'Architecture {arch!r} not supported!') from None + def does_token_look_special(self, token: str) -> bool: + # Some models mark some added tokens which ought to be control tokens as not special. + # (e.g. command-r, command-r-plus, deepseek-coder, gemma{,-2}) + is_known_special = token in ( + "", # deepseek-coder + "", "<2mass>", "[@BOS@]", # gemma{,-2} + ) + # TODO: should these be marked as UNUSED instead? + is_known_special = is_known_special or (token.startswith("")) # gemma{,-2} + + return is_known_special or (token.startswith(("<|", "<|")) and token.endswith(("|>", "|>"))) + # used for GPT-2 BPE and WordPiece vocabs def get_vocab_base(self) -> tuple[list[str], list[int], str]: tokens: list[str] = [] @@ -393,8 +405,9 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]: tokens.append(f"[PAD{i}]") toktypes.append(gguf.TokenType.USER_DEFINED) elif reverse_vocab[i] in added_vocab: - tokens.append(reverse_vocab[i]) - if tokenizer.added_tokens_decoder[i].special: + token: str = reverse_vocab[i] + tokens.append(token) + if tokenizer.added_tokens_decoder[i].special or self.does_token_look_special(token): toktypes.append(gguf.TokenType.CONTROL) else: toktypes.append(gguf.TokenType.USER_DEFINED) diff --git a/src/llama.cpp b/src/llama.cpp index 1794ec2bd8b826..11147eb1159b9d 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5512,14 +5512,6 @@ static void llm_load_vocab( default: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break; } } - - if ((token_data.attr & LLAMA_TOKEN_ATTR_USER_DEFINED) && !token_data.text.empty() && - token_data.text.front() == '<' && token_data.text.back() == '>') { - // Some models mark some added tokens which ought to be control tokens as not special. - // (e.g. command-r, command-r-plus, deepseek-coder) - // TODO: should this be fixed in the convert script instead? - token_data.attr = LLAMA_TOKEN_ATTR_CONTROL; - } } GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size()); From f9d42c598bba0fd10568cbdc507d1de20a17244a Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 7 Jul 2024 23:28:38 -0400 Subject: [PATCH 05/11] convert_hf : identify more added control tokens for SPM tokenziers This makes Gemma and Gemma-2 tokenize pretty much EVERYTHING correctly, including HTML tags and consecutive spaces, but it unfortunately requires model re-conversion. There seems to be a weird behavior of the HF tokenizer for Gemma, which prefers to use the 16-space token over more lengthy space tokens, while using the SentencePiece tokenizer does not do this. (the implementation in llama.cpp has the same behavior as SentencePiece) * llama : fix wrong pre-tokenization of byte tokens --- convert_hf_to_gguf.py | 80 +++++++++++++++++++++------------- src/llama.cpp | 2 +- tests/test-tokenizer-random.py | 10 ++--- 3 files changed, 55 insertions(+), 37 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 30f87a9fe0f0c9..da6d2ba9e26529 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -373,17 +373,28 @@ def from_model_architecture(cls, arch: str) -> type[Model]: except KeyError: raise NotImplementedError(f'Architecture {arch!r} not supported!') from None - def does_token_look_special(self, token: str) -> bool: + def does_token_look_special(self, token: str | bytes) -> bool: + if isinstance(token, (bytes, bytearray)): + token_text = token.decode(encoding="utf-8") + elif isinstance(token, memoryview): + token_text = token.tobytes().decode(encoding="utf-8") + else: + token_text = token + # Some models mark some added tokens which ought to be control tokens as not special. # (e.g. command-r, command-r-plus, deepseek-coder, gemma{,-2}) - is_known_special = token in ( + seems_special = token_text in ( "", # deepseek-coder "", "<2mass>", "[@BOS@]", # gemma{,-2} ) - # TODO: should these be marked as UNUSED instead? - is_known_special = is_known_special or (token.startswith("")) # gemma{,-2} - return is_known_special or (token.startswith(("<|", "<|")) and token.endswith(("|>", "|>"))) + seems_special = seems_special or (token_text.startswith("<|") and token_text.endswith("|>")) + seems_special = seems_special or (token_text.startswith("<|") and token_text.endswith("|>")) # deepseek-coder + + # TODO: should these be marked as UNUSED instead? (maybe not) + seems_special = seems_special or (token_text.startswith("")) # gemma{,-2} + + return seems_special # used for GPT-2 BPE and WordPiece vocabs def get_vocab_base(self) -> tuple[list[str], list[int], str]: @@ -403,17 +414,18 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]: for i in range(vocab_size): if i not in reverse_vocab: tokens.append(f"[PAD{i}]") - toktypes.append(gguf.TokenType.USER_DEFINED) - elif reverse_vocab[i] in added_vocab: + toktypes.append(gguf.TokenType.UNUSED) + else: token: str = reverse_vocab[i] - tokens.append(token) - if tokenizer.added_tokens_decoder[i].special or self.does_token_look_special(token): - toktypes.append(gguf.TokenType.CONTROL) + if token in added_vocab: + if tokenizer.added_tokens_decoder[i].special or self.does_token_look_special(token): + toktypes.append(gguf.TokenType.CONTROL) + else: + token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces + toktypes.append(gguf.TokenType.USER_DEFINED) else: - toktypes.append(gguf.TokenType.USER_DEFINED) - else: - tokens.append(reverse_vocab[i]) - toktypes.append(gguf.TokenType.NORMAL) + toktypes.append(gguf.TokenType.NORMAL) + tokens.append(token) return tokens, toktypes, tokpre @@ -572,7 +584,7 @@ def _set_vocab_qwen(self): for i in range(vocab_size): if i not in reverse_vocab: tokens.append(f"[PAD{i}]") - toktypes.append(gguf.TokenType.USER_DEFINED) + toktypes.append(gguf.TokenType.UNUSED) elif reverse_vocab[i] in added_vocab: tokens.append(reverse_vocab[i]) toktypes.append(gguf.TokenType.CONTROL) @@ -657,6 +669,25 @@ def _create_vocab_sentencepiece(self): scores[token_id] = -1000.0 toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED + tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + if tokenizer_config_file.is_file(): + with open(tokenizer_config_file, "r", encoding="utf-8") as f: + tokenizer_config_json = json.load(f) + added_tokens_decoder = tokenizer_config_json.get("added_tokens_decoder", {}) + for token_id, token_data in added_tokens_decoder.items(): + token_id = int(token_id) + token: str = token_data["content"] + if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN: + assert tokens[token_id] == token.encode("utf-8") + if token_data.get("special") or self.does_token_look_special(token): + toktypes[token_id] = SentencePieceTokenTypes.CONTROL + else: + token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces + toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED + + scores[token_id] = -1000.0 + tokens[token_id] = token.encode("utf-8") + if vocab_size > len(tokens): pad_count = vocab_size - len(tokens) logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]") @@ -1280,7 +1311,7 @@ def set_vocab(self): if (self.dir_model / "tokenizer.json").is_file(): self._set_vocab_gpt2() else: - # StableLM 2 1.6B uses a vocab in a similar format to Qwen's vocab + # StableLM 2 1.6B used to have a vocab in a similar format to Qwen's vocab self._set_vocab_qwen() def set_gguf_parameters(self): @@ -1592,7 +1623,6 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_freq_base(attn_config["rope_theta"]) self.gguf_writer.add_clamp_kqv(attn_config["clip_qkv"]) - self.gguf_writer.add_file_type(self.ftype) self.gguf_writer.add_expert_count(ffn_config["moe_num_experts"]) self.gguf_writer.add_expert_used_count(ffn_config["moe_top_k"]) @@ -2412,19 +2442,7 @@ class Gemma2Model(Model): model_arch = gguf.MODEL_ARCH.GEMMA2 def set_vocab(self): - tokens, scores, toktypes = self._create_vocab_sentencepiece() - # hack: This is required so that we can properly use start/end-of-turn for chat template - for i in range(108): - # including , , - toktypes[i] = SentencePieceTokenTypes.CONTROL - self.gguf_writer.add_tokenizer_model("llama") - self.gguf_writer.add_tokenizer_pre("default") - self.gguf_writer.add_token_list(tokens) - self.gguf_writer.add_token_scores(scores) - self.gguf_writer.add_token_types(toktypes) - - special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) - special_vocab.add_to_gguf(self.gguf_writer) + self._set_vocab_sentencepiece() self.gguf_writer.add_add_space_prefix(False) @@ -3318,7 +3336,7 @@ def set_vocab(self): for i in range(vocab_size): if i not in reverse_vocab: tokens.append(f"[PAD{i}]") - toktypes.append(gguf.TokenType.USER_DEFINED) + toktypes.append(gguf.TokenType.UNUSED) elif reverse_vocab[i] in added_vocab: tokens.append(reverse_vocab[i]) if tokenizer.added_tokens_decoder[i].special: diff --git a/src/llama.cpp b/src/llama.cpp index 11147eb1159b9d..c30d0adfe00b43 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5640,7 +5640,7 @@ static void llm_load_vocab( // build special tokens cache { for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) { - if (!(vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL)) { + if (vocab.id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED)) { vocab.cache_special_tokens.push_back(id); } } diff --git a/tests/test-tokenizer-random.py b/tests/test-tokenizer-random.py index c50a8ca32f6573..cdfc2b12c11789 100644 --- a/tests/test-tokenizer-random.py +++ b/tests/test-tokenizer-random.py @@ -20,7 +20,7 @@ from typing_extensions import Buffer import cffi -from transformers import AutoTokenizer +from transformers import AutoTokenizer, PreTrainedTokenizer logger = logging.getLogger("test-tokenizer-random") @@ -129,7 +129,7 @@ def decode(self, ids: list[int]) -> str: class TokenizerGroundtruth (Tokenizer): def __init__(self, dir_tokenizer: str): - self.model = AutoTokenizer.from_pretrained(dir_tokenizer) + self.model: PreTrainedTokenizer = AutoTokenizer.from_pretrained(dir_tokenizer) # guess BOS and EOS ids = self.encode("a") assert 1 <= len(ids) <= 3 @@ -143,7 +143,7 @@ def __init__(self, dir_tokenizer: str): self.vocab = list(sorted(self.vocab)) # tokens and lists self.special_tokens = list(self.model.all_special_tokens) - self.added_tokens = list(self.model.added_tokens_encoder) + self.added_tokens = self.model.batch_decode(self.model.added_tokens_encoder.values(), skip_special_tokens=False) self.bos_token = self.model.bos_token self.eos_token = self.model.eos_token @@ -458,8 +458,8 @@ def check_detokenizer(text: str, text1: str, text2: str) -> bool: i = find_first_mismatch(ids1, ids2) ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1] ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1] - logger.error(" Expected: " + str(ids1)) - logger.error(" Result: " + str(ids2)) + logger.error(" Expected: " + str(ids1) + f" {[tokenizer1.decode([id]) for id in ids1]}") + logger.error(" Result: " + str(ids2) + f" {[tokenizer2.decode([id]) for id in ids2]}") encode_errors += 1 logger.error(f" {encode_errors=}") if decode_errors < MAX_ERRORS and not check_detokenizer(text, text1, text2): From 31a1b0eeaa2c690f63772844fdac1ac24ed024c8 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 8 Jul 2024 16:34:39 -0400 Subject: [PATCH 06/11] llama : fix Viking pre-tokenizer regex The order was previously wrong, which caused errors in some tests. --- src/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama.cpp b/src/llama.cpp index c30d0adfe00b43..b652762d290006 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -15440,8 +15440,8 @@ struct llm_tokenizer_bpe { break; case LLAMA_VOCAB_PRE_TYPE_VIKING: regex_exprs = { - "\\p{N}", " ?[^(\\s|.,!?…。,、।۔،)]+", + "\\p{N}", }; break; default: From d6fe269ced93d45783a3b37c3cb20554264e5578 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 8 Jul 2024 18:13:16 -0400 Subject: [PATCH 07/11] llama : fix command-r detokenization --- src/llama.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/llama.cpp b/src/llama.cpp index b652762d290006..3509ff59948e8f 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5407,6 +5407,7 @@ static void llm_load_vocab( } else if ( tokenizer_pre == "command-r") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R; + vocab.tokenizer_clean_spaces = false; } else if ( tokenizer_pre == "qwen2") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2; From d4df785868a7a638b20fdf9b9f3e34bc48cdcae3 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 8 Jul 2024 21:09:52 -0400 Subject: [PATCH 08/11] convert_hf : reduce usages of the UNKNOWN token type --- convert_hf_to_gguf.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index da6d2ba9e26529..b2bfb695b92aa9 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -634,7 +634,7 @@ def _create_vocab_sentencepiece(self): tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] scores: list[float] = [-10000.0] * vocab_size - toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size + toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size for token_id in range(tokenizer.vocab_size()): piece = tokenizer.IdToPiece(token_id) @@ -677,7 +677,7 @@ def _create_vocab_sentencepiece(self): for token_id, token_data in added_tokens_decoder.items(): token_id = int(token_id) token: str = token_data["content"] - if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN: + if toktypes[token_id] != SentencePieceTokenTypes.UNUSED: assert tokens[token_id] == token.encode("utf-8") if token_data.get("special") or self.does_token_look_special(token): toktypes[token_id] = SentencePieceTokenTypes.CONTROL @@ -1916,7 +1916,7 @@ def set_vocab(self): tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] scores: list[float] = [-10000.0] * vocab_size - toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size + toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size for token_id in range(tokenizer.vocab_size()): @@ -1961,7 +1961,7 @@ def set_vocab(self): for token_id, foken_data in added_tokens_decoder.items(): token_id = int(token_id) token = foken_data["content"].encode("utf-8") - if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN: + if toktypes[token_id] != SentencePieceTokenTypes.UNUSED: assert tokens[token_id] == token tokens[token_id] = token scores[token_id] = -1000.0 @@ -1977,7 +1977,7 @@ def set_vocab(self): for foken_data in added_tokens: token_id = int(foken_data["id"]) token = foken_data["content"].encode("utf-8") - if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN: + if toktypes[token_id] != SentencePieceTokenTypes.UNUSED: assert tokens[token_id] == token tokens[token_id] = token scores[token_id] = -1000.0 @@ -2766,7 +2766,7 @@ def set_vocab(self): tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] scores: list[float] = [-10000.0] * vocab_size - toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size + toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size for token_id in range(tokenizer.vocab_size()): @@ -3021,7 +3021,7 @@ def set_vocab(self): tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] scores: list[float] = [-10000.0] * vocab_size - toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size + toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size for token_id in range(tokenizer.vocab_size()): piece = tokenizer.IdToPiece(token_id) @@ -3239,15 +3239,14 @@ def set_vocab_chatglm3(self): if len(piece) != 0 and token_id < tokenizer.tokenizer.sp_model.vocab_size(): score = tokenizer.tokenizer.sp_model.get_score(token_id) - if len(piece) == 0: - text = f"[PAD{token_id}]".encode("utf-8") - if token_id >= tokenizer.tokenizer.sp_model.vocab_size(): if piece in special_tokens: - # show special tokens in prompt - toktype = SentencePieceTokenTypes.USER_DEFINED + toktype = SentencePieceTokenTypes.CONTROL + elif len(piece) == 0: + text = f"[PAD{token_id}]".encode("utf-8") + toktype = SentencePieceTokenTypes.UNUSED else: - toktype = SentencePieceTokenTypes.UNKNOWN + toktype = SentencePieceTokenTypes.USER_DEFINED tokens.append(text) scores.append(score) toktypes.append(toktype) From 98edea60bcd67458d70138c13c428414cd8ec63a Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 8 Jul 2024 21:23:19 -0400 Subject: [PATCH 09/11] llama : add UNKNOWN tokens in the special tokens cache --- src/llama.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 3509ff59948e8f..a2e8d62fc1c19e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5641,7 +5641,7 @@ static void llm_load_vocab( // build special tokens cache { for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) { - if (vocab.id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED)) { + if (vocab.id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) { vocab.cache_special_tokens.push_back(id); } } @@ -16168,8 +16168,8 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list< const auto & data = vocab.id_to_token[special_id]; const auto & special_token = data.text; - if (!parse_special && (data.attr & LLAMA_TOKEN_ATTR_CONTROL)) { - // Only ignore control tokens when parse_special == false + if (!parse_special && (data.attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_UNKNOWN))) { + // Ignore control and unknown tokens when parse_special == false continue; // User-defined tokens are still pre-tokenized before everything else // ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726 From dd07a123b79f9bd9e8a4ba0447427b3083e9347a Mon Sep 17 00:00:00 2001 From: Clint Herron Date: Wed, 10 Jul 2024 12:35:18 -0400 Subject: [PATCH 10/11] Name Migration: Build the deprecation-warning 'main' binary every time (#8404) * Modify the deprecation-warning 'main' binary to build every time, instead of only when a legacy binary is present. This is to help users of tutorials and other instruction sets from knowing what to do when the 'main' binary is missing and they are trying to follow instructions. * Adjusting 'server' name-deprecation binary to build all the time, similar to the 'main' legacy name binary. --- Makefile | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/Makefile b/Makefile index 68197fef800199..668b38b99c3129 100644 --- a/Makefile +++ b/Makefile @@ -1513,15 +1513,17 @@ llama-q8dot: pocs/vdot/q8dot.cpp ggml/src/ggml.o \ # Mark legacy binary targets as .PHONY so that they are always checked. .PHONY: main quantize perplexity embedding server finetune +# NOTE: We currently will always build the deprecation-warning `main` and `server` binaries to help users migrate. +# Eventually we will want to remove these target from building all the time. main: examples/deprecation-warning/deprecation-warning.cpp -ifneq (,$(wildcard main)) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) - @echo "#########" - @echo "WARNING: The 'main' binary is deprecated. Please use 'llama-cli' instead." - @echo " Remove the 'main' binary to remove this warning." - @echo "#########" -endif + @echo "NOTICE: The 'main' binary is deprecated. Please use 'llama-cli' instead." + +server: examples/deprecation-warning/deprecation-warning.cpp + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + @echo "NOTICE: The 'server' binary is deprecated. Please use 'llama-server' instead." quantize: examples/deprecation-warning/deprecation-warning.cpp ifneq (,$(wildcard quantize)) @@ -1553,16 +1555,6 @@ ifneq (,$(wildcard embedding)) @echo "#########" endif -server: examples/deprecation-warning/deprecation-warning.cpp -ifneq (,$(wildcard server)) - $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) - $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) - @echo "#########" - @echo "WARNING: The 'server' binary is deprecated. Please use 'llama-server' instead." - @echo " Remove the 'server' binary to remove this warning." - @echo "#########" -endif - finetune: examples/deprecation-warning/deprecation-warning.cpp ifneq (,$(wildcard finetune)) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) From 1caa20fc7a4bd0eac1cc26e5c7262c3dadeaf952 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 10 Jul 2024 17:33:04 -0400 Subject: [PATCH 11/11] convert_hf : reduce usages of UNKNOWN for InternLM2 This makes the changes from #8321 more consistent with the other changes made here. --- convert_hf_to_gguf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0236166b32921f..c15c126eb52df8 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2189,7 +2189,7 @@ def set_vocab(self): toktype = SentencePieceTokenTypes.BYTE # take care of ununsed raw token if piece.startswith('[UNUSED'): - toktype = SentencePieceTokenTypes.UNKNOWN + toktype = SentencePieceTokenTypes.UNUSED tokens.append(text) scores.append(score) @@ -2219,7 +2219,7 @@ def set_vocab(self): if token == chat_eos_token: chat_eos_token_id = token_id token = token.encode("utf-8") - if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN: + if toktypes[token_id] != SentencePieceTokenTypes.UNUSED: assert(tokens[token_id] == token) tokens[token_id] = token scores[token_id] = -1000.0 @@ -2238,7 +2238,7 @@ def set_vocab(self): if token == chat_eos_token: chat_eos_token_id = token_id token = token.encode("utf-8") - if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN: + if toktypes[token_id] != SentencePieceTokenTypes.UNUSED: assert(tokens[token_id] == token) tokens[token_id] = token scores[token_id] = -1000.0