llama : fix pre-tokenization of non-special added tokens #8228
Changes from 7 commits
File 1 of 3: Python model conversion script (HF to GGUF)
@@ -373,6 +373,29 @@ def from_model_architecture(cls, arch: str) -> type[Model]: | |
except KeyError: | ||
raise NotImplementedError(f'Architecture {arch!r} not supported!') from None | ||
|
||
def does_token_look_special(self, token: str | bytes) -> bool: | ||
if isinstance(token, (bytes, bytearray)): | ||
token_text = token.decode(encoding="utf-8") | ||
elif isinstance(token, memoryview): | ||
token_text = token.tobytes().decode(encoding="utf-8") | ||
else: | ||
token_text = token | ||
|
||
# Some models mark some added tokens which ought to be control tokens as not special. | ||
# (e.g. command-r, command-r-plus, deepseek-coder, gemma{,-2}) | ||
seems_special = token_text in ( | ||
"<pad>", # deepseek-coder | ||
"<mask>", "<2mass>", "[@BOS@]", # gemma{,-2} | ||
) | ||
|
||
seems_special = seems_special or (token_text.startswith("<|") and token_text.endswith("|>")) | ||
seems_special = seems_special or (token_text.startswith("<｜") and token_text.endswith("｜>")) # deepseek-coder | ||
|
||
# TODO: should these be marked as UNUSED instead? (maybe not) | ||
seems_special = seems_special or (token_text.startswith("<unused") and token_text.endswith(">")) # gemma{,-2} | ||
Review comment: Should things like this be defined in the conversion script under the specific model, to avoid accidental false hits? If, for some weird reason, a model comes around with a non-special token that starts with `<|`, that would be annoying to work around. Maybe `does_token_look_special` should take in 2 lists: a list of strings of known special tokens, and a list of (starts-with, ends-with) tuples. So for gemma2, we'd call it with something like the sketch below:
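(The reviewer's original snippet is not preserved in this page capture; the block below is a hypothetical reconstruction of the proposed two-list API, with the parameter names, affixes, and example token assumed rather than taken from the PR.)

```python
# Hypothetical reconstruction of the suggested API (not the merged implementation):
# the base heuristic takes model-supplied lists instead of hard-coded patterns.

def does_token_look_special(token_text: str,
                            known_special: list[str],
                            special_affixes: list[tuple[str, str]]) -> bool:
    # exact matches for tokens known to behave like control tokens
    if token_text in known_special:
        return True
    # (prefix, suffix) pairs, e.g. ("<unused", ">") for gemma{,-2}
    return any(token_text.startswith(p) and token_text.endswith(s)
               for p, s in special_affixes)

# a gemma-2 style call (token names taken from the heuristics in this diff):
print(does_token_look_special("<unused42>",
                              known_special=["<mask>", "<2mass>", "[@BOS@]"],
                              special_affixes=[("<unused", ">")]))  # True
```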
and then here we'd have:
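(Again a reconstruction, not the reviewer's actual snippet; the variable names and the example token are assumed.)

```python
# What the check at this point might then reduce to under the two-list proposal:
token_text = "<|fim_begin|>"          # made-up example token
known_special = ["<pad>", "<mask>"]
special_affixes = [("<|", "|>")]

seems_special = token_text in known_special or any(
    token_text.startswith(prefix) and token_text.endswith(suffix)
    for prefix, suffix in special_affixes
)
print(seems_special)  # True
```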
Follow-up: Oh, I realize I misread the structure of the code. Hmm, still not impossible, but the lists would have to be passed in at a higher level.
Reply: This only affects added tokens (either from `tokenizer.json` or from `added_tokens_decoder` in `tokenizer_config.json`). I did not yet notice any conflict in the […]. Also, this is a method of `Model`, so it can be overridden in model-specific subclasses if needed.
||
|
||
return seems_special | ||
|
||
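For reference, a small self-contained restatement of the heuristic above (simplified to a plain function taking `str`, and omitting the fullwidth-bar variant), with a few hand-checked examples:

```python
# Simplified restatement of does_token_look_special for quick experimentation;
# same checks as the diff above, minus the bytes/memoryview handling.
def looks_special(token_text: str) -> bool:
    if token_text in ("<pad>", "<mask>", "<2mass>", "[@BOS@]"):
        return True
    if token_text.startswith("<|") and token_text.endswith("|>"):
        return True
    return token_text.startswith("<unused") and token_text.endswith(">")

assert looks_special("<pad>")           # listed explicitly (deepseek-coder)
assert looks_special("<unused12>")      # gemma{,-2} "<unused...>" affix check
assert looks_special("<|endoftext|>")   # "<|" ... "|>" affix check
assert not looks_special("hello")       # ordinary text token
```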
# used for GPT-2 BPE and WordPiece vocabs | ||
def get_vocab_base(self) -> tuple[list[str], list[int], str]: | ||
tokens: list[str] = [] | ||
|
@@ -391,16 +414,18 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]: | |
for i in range(vocab_size): | ||
if i not in reverse_vocab: | ||
tokens.append(f"[PAD{i}]") | ||
toktypes.append(gguf.TokenType.USER_DEFINED) | ||
elif reverse_vocab[i] in added_vocab: | ||
tokens.append(reverse_vocab[i]) | ||
if tokenizer.added_tokens_decoder[i].special: | ||
toktypes.append(gguf.TokenType.CONTROL) | ||
else: | ||
toktypes.append(gguf.TokenType.USER_DEFINED) | ||
toktypes.append(gguf.TokenType.UNUSED) | ||
Review comment: UNUSED solves a lot of vocab and added token problems.
||
else: | ||
tokens.append(reverse_vocab[i]) | ||
toktypes.append(gguf.TokenType.NORMAL) | ||
token: str = reverse_vocab[i] | ||
if token in added_vocab: | ||
if tokenizer.added_tokens_decoder[i].special or self.does_token_look_special(token): | ||
toktypes.append(gguf.TokenType.CONTROL) | ||
else: | ||
token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces | ||
toktypes.append(gguf.TokenType.USER_DEFINED) | ||
else: | ||
toktypes.append(gguf.TokenType.NORMAL) | ||
tokens.append(token) | ||
|
||
return tokens, toktypes, tokpre | ||
|
||
|
@@ -559,7 +584,7 @@ def _set_vocab_qwen(self): | |
for i in range(vocab_size): | ||
if i not in reverse_vocab: | ||
tokens.append(f"[PAD{i}]") | ||
toktypes.append(gguf.TokenType.USER_DEFINED) | ||
toktypes.append(gguf.TokenType.UNUSED) | ||
Review comment on lines 586 to 587: Padding tokens are set as `UNUSED`.
||
elif reverse_vocab[i] in added_vocab: | ||
tokens.append(reverse_vocab[i]) | ||
toktypes.append(gguf.TokenType.CONTROL) | ||
|
@@ -644,6 +669,25 @@ def _create_vocab_sentencepiece(self): | |
scores[token_id] = -1000.0 | ||
toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED | ||
|
||
tokenizer_config_file = self.dir_model / 'tokenizer_config.json' | ||
if tokenizer_config_file.is_file(): | ||
with open(tokenizer_config_file, "r", encoding="utf-8") as f: | ||
tokenizer_config_json = json.load(f) | ||
added_tokens_decoder = tokenizer_config_json.get("added_tokens_decoder", {}) | ||
for token_id, token_data in added_tokens_decoder.items(): | ||
token_id = int(token_id) | ||
token: str = token_data["content"] | ||
if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN: | ||
|
||
assert tokens[token_id] == token.encode("utf-8") | ||
if token_data.get("special") or self.does_token_look_special(token): | ||
toktypes[token_id] = SentencePieceTokenTypes.CONTROL | ||
else: | ||
token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces | ||
toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED | ||
|
||
scores[token_id] = -1000.0 | ||
tokens[token_id] = token.encode("utf-8") | ||
|
||
if vocab_size > len(tokens): | ||
pad_count = vocab_size - len(tokens) | ||
logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]") | ||
|
@@ -1267,7 +1311,7 @@ def set_vocab(self): | |
if (self.dir_model / "tokenizer.json").is_file(): | ||
self._set_vocab_gpt2() | ||
else: | ||
# StableLM 2 1.6B uses a vocab in a similar format to Qwen's vocab | ||
# StableLM 2 1.6B used to have a vocab in a similar format to Qwen's vocab | ||
self._set_vocab_qwen() | ||
|
||
def set_gguf_parameters(self): | ||
|
@@ -1579,7 +1623,6 @@ def set_gguf_parameters(self): | |
self.gguf_writer.add_rope_freq_base(attn_config["rope_theta"]) | ||
|
||
self.gguf_writer.add_clamp_kqv(attn_config["clip_qkv"]) | ||
self.gguf_writer.add_file_type(self.ftype) | ||
|
||
self.gguf_writer.add_expert_count(ffn_config["moe_num_experts"]) | ||
self.gguf_writer.add_expert_used_count(ffn_config["moe_top_k"]) | ||
|
@@ -2399,19 +2442,7 @@ class Gemma2Model(Model): | |
model_arch = gguf.MODEL_ARCH.GEMMA2 | ||
|
||
def set_vocab(self): | ||
tokens, scores, toktypes = self._create_vocab_sentencepiece() | ||
# hack: This is required so that we can properly use start/end-of-turn for chat template | ||
for i in range(108): | ||
# including <unusedX>, <start_of_turn>, <end_of_turn> | ||
toktypes[i] = SentencePieceTokenTypes.CONTROL | ||
self.gguf_writer.add_tokenizer_model("llama") | ||
self.gguf_writer.add_tokenizer_pre("default") | ||
self.gguf_writer.add_token_list(tokens) | ||
self.gguf_writer.add_token_scores(scores) | ||
self.gguf_writer.add_token_types(toktypes) | ||
|
||
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) | ||
special_vocab.add_to_gguf(self.gguf_writer) | ||
self._set_vocab_sentencepiece() | ||
Review comment on lines -2438 to +2481: The "hack" from #8244 is no longer required because the control tokens are now identified from `added_tokens_decoder` (and `does_token_look_special`) in `_create_vocab_sentencepiece`.
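Not part of the PR, but a minimal sketch of the classification step described above, assuming a hand-written `added_tokens_decoder`-style dict; the token IDs, flags, and enum values are illustrative only:

```python
# Illustrative only: classify added tokens the way the new sentencepiece path does,
# using the "special" flag plus a gemma-style "<unusedN>" heuristic.
from enum import IntEnum

class TokType(IntEnum):       # stand-in for SentencePieceTokenTypes
    CONTROL = 3
    USER_DEFINED = 4

def looks_special(text: str) -> bool:
    # simplified version of the heuristics used by the conversion script
    return (text.startswith("<unused") and text.endswith(">")) or (
        text.startswith("<|") and text.endswith("|>"))

# hand-written entries in added_tokens_decoder form; IDs and flags are made up
added_tokens_decoder = {
    "106": {"content": "<start_of_turn>", "special": True},
    "7":   {"content": "<unused0>",       "special": False},
}

for token_id, data in added_tokens_decoder.items():
    toktype = TokType.CONTROL if data["special"] or looks_special(data["content"]) else TokType.USER_DEFINED
    print(token_id, data["content"], toktype.name)  # both come out as CONTROL
```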
||
|
||
self.gguf_writer.add_add_space_prefix(False) | ||
|
||
|
@@ -3305,7 +3336,7 @@ def set_vocab(self): | |
for i in range(vocab_size): | ||
if i not in reverse_vocab: | ||
tokens.append(f"[PAD{i}]") | ||
toktypes.append(gguf.TokenType.USER_DEFINED) | ||
toktypes.append(gguf.TokenType.UNUSED) | ||
elif reverse_vocab[i] in added_vocab: | ||
tokens.append(reverse_vocab[i]) | ||
if tokenizer.added_tokens_decoder[i].special: | ||
|
File 2 of 3: llama.cpp C++ source (vocab loading and tokenizer)
@@ -5640,7 +5640,7 @@ static void llm_load_vocab( | |
// build special tokens cache | ||
{ | ||
for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) { | ||
if (!(vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL)) { | ||
if (vocab.id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED)) { | ||
Review comment: It's necessary to limit the special token cache to `CONTROL` and `USER_DEFINED` tokens. But maybe it's a bad idea to exclude other types of tokens, like `UNKNOWN`.
Reply: I agree. I think it is correct to drop UNUSED tokens, but we need to parse UNKNOWN tokens.
Reply: Should UNKNOWN tokens only be parsed when `parse_special == true`? I'm not sure if UNKNOWN tokens should be specially parsed at all; I would tend toward only parsing them with `parse_special == true`. I'm trying to figure out where UNKNOWN tokens are used and if it's useful to specially parse them, but this might differ from HF's tokenizers.
Reply: I've fixed this in 98edea6. UNKNOWN tokens are now parsed when `parse_special == true`.
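A conceptual Python sketch (not the C++ from this PR) of the two decisions discussed in this thread, which token types enter the special-token cache and which cached tokens are matched during partitioning depending on `parse_special`, reflecting the post-98edea6 behavior described in the last reply:

```python
# Conceptual sketch of the behavior discussed above, in Python for brevity.
# Attribute names mirror the C++ enum; this is a simplification, not the real code.
from enum import Flag, auto

class TokenAttr(Flag):
    NORMAL = auto()
    UNKNOWN = auto()
    CONTROL = auto()
    USER_DEFINED = auto()
    UNUSED = auto()

def in_special_cache(attr: TokenAttr) -> bool:
    # CONTROL and USER_DEFINED are cached; per the reply above, UNKNOWN is too after 98edea6
    return bool(attr & (TokenAttr.CONTROL | TokenAttr.USER_DEFINED | TokenAttr.UNKNOWN))

def matched_during_partition(attr: TokenAttr, parse_special: bool) -> bool:
    # user-defined tokens are always split out; control/unknown tokens only with parse_special
    if not in_special_cache(attr):
        return False
    if attr & TokenAttr.USER_DEFINED:
        return True
    return parse_special

print(matched_during_partition(TokenAttr.USER_DEFINED, parse_special=False))  # True
print(matched_during_partition(TokenAttr.CONTROL, parse_special=False))       # False
```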
||
vocab.cache_special_tokens.push_back(id); | ||
} | ||
} | ||
|
@@ -15404,17 +15404,6 @@ struct llm_tokenizer_bpe { | |
"[0-9][0-9][0-9]", | ||
}; | ||
break; | ||
case LLAMA_VOCAB_PRE_TYPE_MPT: | ||
// TODO: MPT pre-tokenization regexes are unknown | ||
// the following are close, but not exact. run the following: | ||
// ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf | ||
GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed"); | ||
regex_exprs = { | ||
"\\s?\\p{L}+", | ||
"\\s?\\p{P}+", | ||
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", | ||
}; | ||
break; | ||
case LLAMA_VOCAB_PRE_TYPE_STARCODER: | ||
case LLAMA_VOCAB_PRE_TYPE_REFACT: | ||
case LLAMA_VOCAB_PRE_TYPE_COMMAND_R: | ||
|
@@ -15424,6 +15413,7 @@ struct llm_tokenizer_bpe { | |
}; | ||
break; | ||
case LLAMA_VOCAB_PRE_TYPE_GPT2: | ||
case LLAMA_VOCAB_PRE_TYPE_MPT: | ||
case LLAMA_VOCAB_PRE_TYPE_OLMO: | ||
case LLAMA_VOCAB_PRE_TYPE_JAIS: | ||
regex_exprs = { | ||
|
@@ -16171,12 +16161,20 @@ struct fragment_buffer_variant { | |
|
||
// #define PRETOKENIZERDEBUG | ||
|
||
static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) { | ||
static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer, bool parse_special) { | ||
|
||
// for each special token | ||
for (const llama_vocab::id special_id : vocab.cache_special_tokens) { | ||
const auto & data = vocab.id_to_token[special_id]; | ||
const auto & special_token = data.text; | ||
|
||
if (!parse_special && (data.attr & LLAMA_TOKEN_ATTR_CONTROL)) { | ||
// Only ignore control tokens when parse_special == false | ||
continue; | ||
// User-defined tokens are still pre-tokenized before everything else | ||
// ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726 | ||
// This is mostly relevant for neox-style tokenizers (mpt, olmo, stablelm, etc.) | ||
} | ||
|
||
// for each text fragment | ||
std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin(); | ||
while (it != buffer.end()) { | ||
|
@@ -16289,7 +16287,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & | |
|
||
if (!raw_text.empty()) { | ||
fragment_buffer.emplace_front(raw_text, 0, raw_text.length()); | ||
if (parse_special) tokenizer_st_partition(vocab, fragment_buffer); | ||
tokenizer_st_partition(vocab, fragment_buffer, parse_special); | ||
} | ||
|
||
switch (vocab.type) { | ||
|
File 3 of 3: Python tokenizer test script (test-tokenizer-random)
@@ -20,7 +20,7 @@ | |
from typing_extensions import Buffer | ||
|
||
import cffi | ||
from transformers import AutoTokenizer | ||
from transformers import AutoTokenizer, PreTrainedTokenizer | ||
|
||
|
||
logger = logging.getLogger("test-tokenizer-random") | ||
|
@@ -129,7 +129,7 @@ def decode(self, ids: list[int]) -> str: | |
class TokenizerGroundtruth (Tokenizer): | ||
|
||
def __init__(self, dir_tokenizer: str): | ||
self.model = AutoTokenizer.from_pretrained(dir_tokenizer) | ||
self.model: PreTrainedTokenizer = AutoTokenizer.from_pretrained(dir_tokenizer) | ||
# guess BOS and EOS | ||
ids = self.encode("a") | ||
assert 1 <= len(ids) <= 3 | ||
|
@@ -143,7 +143,7 @@ def __init__(self, dir_tokenizer: str): | |
self.vocab = list(sorted(self.vocab)) | ||
# tokens and lists | ||
self.special_tokens = list(self.model.all_special_tokens) | ||
self.added_tokens = list(self.model.added_tokens_encoder) | ||
self.added_tokens = self.model.batch_decode(self.model.added_tokens_encoder.values(), skip_special_tokens=False) | ||
Review comment: "PEP 8: E221 multiple spaces before operator" (just FYI - looks good to me).
||
self.bos_token = self.model.bos_token | ||
self.eos_token = self.model.eos_token | ||
|
||
|
@@ -458,8 +458,8 @@ def check_detokenizer(text: str, text1: str, text2: str) -> bool: | |
i = find_first_mismatch(ids1, ids2) | ||
ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1] | ||
ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1] | ||
logger.error(" Expected: " + str(ids1)) | ||
logger.error(" Result: " + str(ids2)) | ||
logger.error(" Expected: " + str(ids1) + f" {[tokenizer1.decode([id]) for id in ids1]}") | ||
|
||
logger.error(" Result: " + str(ids2) + f" {[tokenizer2.decode([id]) for id in ids2]}") | ||
Review comment: "Shadows built-in name 'id'".
||
encode_errors += 1 | ||
logger.error(f" {encode_errors=}") | ||
if decode_errors < MAX_ERRORS and not check_detokenizer(text, text1, text2): | ||
|
Review comment: "Method 'does_token_look_special' may be 'static'"
Reply: I prefer not to make it a `@staticmethod`, to allow overriding it in the subclasses of `Model` if needed.
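To illustrate the design choice in that reply, a hypothetical subclass could override the heuristic like this; the subclass and its token pattern are made up, not from the PR:

```python
# Hypothetical example of overriding the heuristic in a model-specific converter.
# The base class sketch mirrors the structure of the conversion script, but this
# subclass and its control-token pattern are illustrative only.

class Model:
    def does_token_look_special(self, token: str | bytes) -> bool:
        token_text = token.decode("utf-8") if isinstance(token, (bytes, bytearray)) else token
        return token_text.startswith("<|") and token_text.endswith("|>")

class MyQuirkyModel(Model):
    def does_token_look_special(self, token: str | bytes) -> bool:
        token_text = token.decode("utf-8") if isinstance(token, (bytes, bytearray)) else token
        # this (made-up) model uses [CTRL_...] markers for its control tokens
        if token_text.startswith("[CTRL_") and token_text.endswith("]"):
            return True
        return super().does_token_look_special(token)

print(MyQuirkyModel().does_token_look_special("[CTRL_STOP]"))  # True
```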