llama : fix pre-tokenization of non-special added tokens #8228

Merged: 14 commits, merged Jul 14, 2024. Changes shown from 13 commits.
112 changes: 71 additions & 41 deletions convert_hf_to_gguf.py
@@ -373,6 +373,29 @@ def from_model_architecture(cls, arch: str) -> type[Model]:
except KeyError:
raise NotImplementedError(f'Architecture {arch!r} not supported!') from None

def does_token_look_special(self, token: str | bytes) -> bool:
Reviewer comment: "Method 'does_token_look_special' may be 'static'"

compilade (Collaborator, Author): I prefer not to make it a @staticmethod, to allow overriding it in the subclasses of Model if needed.
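For illustration only (not part of the diff), a minimal sketch of what such an override could look like; the architecture name and the exempted token are made up:

@Model.register("SomeHypotheticalForCausalLM")  # hypothetical architecture name
class SomeHypotheticalModel(Model):
    model_arch = gguf.MODEL_ARCH.LLAMA  # placeholder for the sketch

    def does_token_look_special(self, token: str | bytes) -> bool:
        # Suppose this model had a genuinely non-special added token "<|plain|>"
        # that would otherwise be caught by the generic "<|...|>" heuristic.
        if token in ("<|plain|>", b"<|plain|>"):
            return False
        return super().does_token_look_special(token)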

if isinstance(token, (bytes, bytearray)):
token_text = token.decode(encoding="utf-8")
elif isinstance(token, memoryview):
token_text = token.tobytes().decode(encoding="utf-8")
else:
token_text = token

# Some models mark some added tokens which ought to be control tokens as not special.
# (e.g. command-r, command-r-plus, deepseek-coder, gemma{,-2})
seems_special = token_text in (
"<pad>", # deepseek-coder
"<mask>", "<2mass>", "[@BOS@]", # gemma{,-2}
)

seems_special = seems_special or (token_text.startswith("<|") and token_text.endswith("|>"))
seems_special = seems_special or (token_text.startswith("<｜") and token_text.endswith("｜>")) # deepseek-coder

# TODO: should these be marked as UNUSED instead? (maybe not)
seems_special = seems_special or (token_text.startswith("<unused") and token_text.endswith(">")) # gemma{,-2}
Contributor comment: should things like this be defined in the conversion script under the specific model to avoid accidental false hits? if, for some weird reason, a model comes around with a non-special token that starts with <|, it would be annoying to avoid that

maybe does_token_look_special should take in 2 lists: 1 list of strings of known special tokens, and a list of tuples of starts/ends with tokens

So for gemma2, we'd call it with:

special_tokens = ["<mask>", "<2mass>", "[@BOS@]"]
special_tags = [("<unused", "|>")]

self.does_token_look_special(token, special_tokens, special_tags)

and then here we'd have:

seems_special = token_text in special_tokens
for start_tag, end_tag in special_tags:
  seems_special = seems_special or (token_text.startswith(start_tag) and token_text.endswith(end_tag))

return seems_special

Contributor: oh i realize I misread the structure of the code, hmmm.. still not impossible but would have to be passed at a higher level

compilade (Collaborator, Author), Jul 8, 2024, replying to "if, for some weird reason, a model comes around with a non-special token that starts with <|, would be annoying to avoid that":

This only affects added tokens either from added_tokens in tokenizer.json or from added_tokens_decoder in tokenizer_config.json, so it does not affect normal tokens starting with <| in any way. Not all tokens of the vocab are checked with this, only the ones part of added_tokens (which are treated specially by HF tokenizers too anyway). And added tokens starting with <| and ending with |> are arguably always control tokens; this was added pretty much because some model makers wrongly marked those as non-special (notably, <|User|>, <|Assistant|> and <|EOT|> in deepseek-coder are supposedly non-special. Same with <|START_OF_TURN_TOKEN|> and <|END_OF_TURN_TOKEN|> for command-r).

I did not yet notice any conflict in the added_tokens to justify making model-specific checks instead of always checking for all known "special-but-arent-marked-special" tokens.

Also, this is a method of Model, so it technically can be overridden by subclasses should there ever be a model with conflicting added_tokens.
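To make the failure mode concrete, here is a sketch of the kind of added_tokens_decoder entry being described, shown as a Python dict; the token id and field values are illustrative, the field names follow the usual tokenizer_config.json format:

# Example shape of an added_tokens_decoder entry (id and values are illustrative):
added_tokens_decoder = {
    "32021": {
        "content": "<|EOT|>",   # looks like a control token...
        "special": False,       # ...but is marked non-special by the model maker
        "normalized": True,
        "lstrip": False,
        "rstrip": False,
        "single_word": False,
    },
}
# With this change, does_token_look_special("<|EOT|>") returns True, so the token
# is still stored as CONTROL in the GGUF despite "special": false.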


return seems_special

# used for GPT-2 BPE and WordPiece vocabs
def get_vocab_base(self) -> tuple[list[str], list[int], str]:
tokens: list[str] = []
@@ -391,16 +414,18 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]:
for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.USER_DEFINED)
elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i])
if tokenizer.added_tokens_decoder[i].special:
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.USER_DEFINED)
toktypes.append(gguf.TokenType.UNUSED)
Collaborator comment: UNUSED solves a lot of vocab and added token problems. Also, reducing the number of added tokens improves speed.

else:
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.NORMAL)
token: str = reverse_vocab[i]
if token in added_vocab:
if tokenizer.added_tokens_decoder[i].special or self.does_token_look_special(token):
toktypes.append(gguf.TokenType.CONTROL)
else:
token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces
toktypes.append(gguf.TokenType.USER_DEFINED)
else:
toktypes.append(gguf.TokenType.NORMAL)
tokens.append(token)

return tokens, toktypes, tokpre
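As an aside (not part of the diff), the b"\xe2\x96\x81" byte string used above is the UTF-8 encoding of U+2581 ("▁"), the SentencePiece-style space marker; the pre-normalization simply turns it back into a regular space. A minimal sketch with a made-up user-defined token:

marker = b"\xe2\x96\x81".decode("utf-8")   # U+2581, "▁"
token = "▁▁Hello"                          # hypothetical user-defined token text
print(token.replace(marker, " "))          # -> "  Hello", the form stored in the GGUF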

@@ -559,7 +584,7 @@ def _set_vocab_qwen(self):
for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.USER_DEFINED)
toktypes.append(gguf.TokenType.UNUSED)
compilade (Collaborator, Author), commenting on lines 586 to +587: Padding tokens are set as UNUSED to reflect how it was already done in _set_vocab_sentencepiece, and also to avoid wrongly (pre-)tokenizing strings which happen to correspond to a padding token (since USER_DEFINED tokens are now always pre-tokenized specially).

elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.CONTROL)
@@ -609,7 +634,7 @@ def _create_vocab_sentencepiece(self):

tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
scores: list[float] = [-10000.0] * vocab_size
toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size
toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size

for token_id in range(tokenizer.vocab_size()):
piece = tokenizer.IdToPiece(token_id)
Expand Down Expand Up @@ -644,6 +669,25 @@ def _create_vocab_sentencepiece(self):
scores[token_id] = -1000.0
toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED

tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
if tokenizer_config_file.is_file():
with open(tokenizer_config_file, "r", encoding="utf-8") as f:
tokenizer_config_json = json.load(f)
added_tokens_decoder = tokenizer_config_json.get("added_tokens_decoder", {})
for token_id, token_data in added_tokens_decoder.items():
token_id = int(token_id)
token: str = token_data["content"]
if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
assert tokens[token_id] == token.encode("utf-8")
if token_data.get("special") or self.does_token_look_special(token):
toktypes[token_id] = SentencePieceTokenTypes.CONTROL
else:
token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces
toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED

scores[token_id] = -1000.0
tokens[token_id] = token.encode("utf-8")

if vocab_size > len(tokens):
pad_count = vocab_size - len(tokens)
logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
@@ -1267,7 +1311,7 @@ def set_vocab(self):
if (self.dir_model / "tokenizer.json").is_file():
self._set_vocab_gpt2()
else:
# StableLM 2 1.6B uses a vocab in a similar format to Qwen's vocab
# StableLM 2 1.6B used to have a vocab in a similar format to Qwen's vocab
self._set_vocab_qwen()

def set_gguf_parameters(self):
@@ -1579,7 +1623,6 @@ def set_gguf_parameters(self):
self.gguf_writer.add_rope_freq_base(attn_config["rope_theta"])

self.gguf_writer.add_clamp_kqv(attn_config["clip_qkv"])
self.gguf_writer.add_file_type(self.ftype)

self.gguf_writer.add_expert_count(ffn_config["moe_num_experts"])
self.gguf_writer.add_expert_used_count(ffn_config["moe_top_k"])
@@ -1873,7 +1916,7 @@ def set_vocab(self):

tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
scores: list[float] = [-10000.0] * vocab_size
toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size
toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size

for token_id in range(tokenizer.vocab_size()):

@@ -1918,7 +1961,7 @@ def set_vocab(self):
for token_id, foken_data in added_tokens_decoder.items():
token_id = int(token_id)
token = foken_data["content"].encode("utf-8")
if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN:
if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
assert tokens[token_id] == token
tokens[token_id] = token
scores[token_id] = -1000.0
@@ -1934,7 +1977,7 @@ def set_vocab(self):
for foken_data in added_tokens:
token_id = int(foken_data["id"])
token = foken_data["content"].encode("utf-8")
if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN:
if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
assert tokens[token_id] == token
tokens[token_id] = token
scores[token_id] = -1000.0
@@ -2146,7 +2189,7 @@ def set_vocab(self):
toktype = SentencePieceTokenTypes.BYTE
# take care of ununsed raw token
if piece.startswith('[UNUSED'):
toktype = SentencePieceTokenTypes.UNKNOWN
toktype = SentencePieceTokenTypes.UNUSED

tokens.append(text)
scores.append(score)
@@ -2176,7 +2219,7 @@ def set_vocab(self):
if token == chat_eos_token:
chat_eos_token_id = token_id
token = token.encode("utf-8")
if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN:
if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
assert(tokens[token_id] == token)
tokens[token_id] = token
scores[token_id] = -1000.0
@@ -2195,7 +2238,7 @@ def set_vocab(self):
if token == chat_eos_token:
chat_eos_token_id = token_id
token = token.encode("utf-8")
if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN:
if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
assert(tokens[token_id] == token)
tokens[token_id] = token
scores[token_id] = -1000.0
@@ -2435,19 +2478,7 @@ class Gemma2Model(Model):
model_arch = gguf.MODEL_ARCH.GEMMA2

def set_vocab(self):
tokens, scores, toktypes = self._create_vocab_sentencepiece()
# hack: This is required so that we can properly use start/end-of-turn for chat template
for i in range(108):
# including <unusedX>, <start_of_turn>, <end_of_turn>
toktypes[i] = SentencePieceTokenTypes.CONTROL
self.gguf_writer.add_tokenizer_model("llama")
self.gguf_writer.add_tokenizer_pre("default")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)
self._set_vocab_sentencepiece()
compilade (Collaborator, Author), commenting on lines -2438 to +2481: The "hack" from #8244 is no longer required because the control tokens are now identified with tokenizer_config.json.


self.gguf_writer.add_add_space_prefix(False)

@@ -2771,7 +2802,7 @@ def set_vocab(self):

tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
scores: list[float] = [-10000.0] * vocab_size
toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size
toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size

for token_id in range(tokenizer.vocab_size()):

@@ -3026,7 +3057,7 @@ def set_vocab(self):

tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
scores: list[float] = [-10000.0] * vocab_size
toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size
toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size

for token_id in range(tokenizer.vocab_size()):
piece = tokenizer.IdToPiece(token_id)
@@ -3244,15 +3275,14 @@ def set_vocab_chatglm3(self):
if len(piece) != 0 and token_id < tokenizer.tokenizer.sp_model.vocab_size():
score = tokenizer.tokenizer.sp_model.get_score(token_id)

if len(piece) == 0:
text = f"[PAD{token_id}]".encode("utf-8")

if token_id >= tokenizer.tokenizer.sp_model.vocab_size():
if piece in special_tokens:
# show special tokens in prompt
toktype = SentencePieceTokenTypes.USER_DEFINED
toktype = SentencePieceTokenTypes.CONTROL
elif len(piece) == 0:
text = f"[PAD{token_id}]".encode("utf-8")
toktype = SentencePieceTokenTypes.UNUSED
else:
toktype = SentencePieceTokenTypes.UNKNOWN
toktype = SentencePieceTokenTypes.USER_DEFINED
tokens.append(text)
scores.append(score)
toktypes.append(toktype)
@@ -3341,7 +3371,7 @@ def set_vocab(self):
for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.USER_DEFINED)
toktypes.append(gguf.TokenType.UNUSED)
elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i])
if tokenizer.added_tokens_decoder[i].special:
29 changes: 14 additions & 15 deletions src/llama.cpp
@@ -5419,6 +5419,7 @@ static void llm_load_vocab(
} else if (
tokenizer_pre == "command-r") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R;
vocab.tokenizer_clean_spaces = false;
} else if (
tokenizer_pre == "qwen2") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2;
@@ -5652,7 +5653,7 @@ static void llm_load_vocab(
// build special tokens cache
{
for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
if (!(vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL)) {
if (vocab.id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
vocab.cache_special_tokens.push_back(id);
}
}
@@ -15418,17 +15419,6 @@ struct llm_tokenizer_bpe {
"[0-9][0-9][0-9]",
};
break;
case LLAMA_VOCAB_PRE_TYPE_MPT:
// TODO: MPT pre-tokenization regexes are unknown
// the following are close, but not exact. run the following:
// ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf
GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed");
regex_exprs = {
"\\s?\\p{L}+",
"\\s?\\p{P}+",
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
};
break;
case LLAMA_VOCAB_PRE_TYPE_STARCODER:
case LLAMA_VOCAB_PRE_TYPE_REFACT:
case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
@@ -15438,6 +15428,7 @@ struct llm_tokenizer_bpe {
};
break;
case LLAMA_VOCAB_PRE_TYPE_GPT2:
case LLAMA_VOCAB_PRE_TYPE_MPT:
case LLAMA_VOCAB_PRE_TYPE_OLMO:
case LLAMA_VOCAB_PRE_TYPE_JAIS:
regex_exprs = {
@@ -15464,8 +15455,8 @@ struct llm_tokenizer_bpe {
break;
case LLAMA_VOCAB_PRE_TYPE_VIKING:
regex_exprs = {
"\\p{N}",
" ?[^(\\s|.,!?…。,、।۔،)]+",
"\\p{N}",
};
break;
default:
@@ -16185,12 +16176,20 @@ struct fragment_buffer_variant {

// #define PRETOKENIZERDEBUG

static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer, bool parse_special) {
// for each special token
for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
const auto & data = vocab.id_to_token[special_id];
const auto & special_token = data.text;

if (!parse_special && (data.attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_UNKNOWN))) {
// Ignore control and unknown tokens when parse_special == false
continue;
// User-defined tokens are still pre-tokenized before everything else
// ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726
// This is mostly relevant for neox-style tokenizers (mpt, olmo, stablelm, etc.)
}

// for each text fragment
std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
while (it != buffer.end()) {
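For reference (not part of the diff), the huggingface/tokenizers behavior referred to above can be reproduced with transformers; a minimal sketch, assuming the stock gpt2 tokenizer is available for download:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
tok.add_tokens(["<custom-tag>"], special_tokens=False)  # a non-special (user-defined) added token
print(tok.tokenize("foo<custom-tag>bar"))
# The added token is carved out as a single piece before the BPE pre-tokenizer runs,
# which is the behavior mirrored here: USER_DEFINED tokens are still partitioned
# even when parse_special == false.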
@@ -16303,7 +16302,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &

if (!raw_text.empty()) {
fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
if (parse_special) tokenizer_st_partition(vocab, fragment_buffer);
tokenizer_st_partition(vocab, fragment_buffer, parse_special);
}

switch (vocab.type) {
4 changes: 2 additions & 2 deletions tests/test-tokenizer-0.cpp
@@ -195,7 +195,7 @@ int main(int argc, char **argv) {
const bool add_special = false;

for (const auto & test_kv : k_tests) {
const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, add_special, true);
const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, add_special, false);

printf("\n");
printf("src: '%s'\n", test_kv.first.c_str());
@@ -253,7 +253,7 @@
{
const auto t_start = ggml_time_us();

res = llama_tokenize(ctx, text, add_special, true);
res = llama_tokenize(ctx, text, add_special, false);

const auto t_end = ggml_time_us();

10 changes: 5 additions & 5 deletions tests/test-tokenizer-random.py
@@ -20,7 +20,7 @@
from typing_extensions import Buffer

import cffi
from transformers import AutoTokenizer
from transformers import AutoTokenizer, PreTrainedTokenizer


logger = logging.getLogger("test-tokenizer-random")
@@ -129,7 +129,7 @@ def decode(self, ids: list[int]) -> str:
class TokenizerGroundtruth (Tokenizer):

def __init__(self, dir_tokenizer: str):
self.model = AutoTokenizer.from_pretrained(dir_tokenizer)
self.model: PreTrainedTokenizer = AutoTokenizer.from_pretrained(dir_tokenizer)
# guess BOS and EOS
ids = self.encode("a")
assert 1 <= len(ids) <= 3
@@ -143,7 +143,7 @@ def __init__(self, dir_tokenizer: str):
self.vocab = list(sorted(self.vocab))
# tokens and lists
self.special_tokens = list(self.model.all_special_tokens)
self.added_tokens = list(self.model.added_tokens_encoder)
self.added_tokens = self.model.batch_decode(self.model.added_tokens_encoder.values(), skip_special_tokens=False)

Reviewer comment: "PEP 8: E221 multiple spaces before operator" (just FYI - looks good to me); "PEP 8: E501 line too long (122 > 120 characters)"

self.bos_token = self.model.bos_token
self.eos_token = self.model.eos_token

@@ -458,8 +458,8 @@ def check_detokenizer(text: str, text1: str, text2: str) -> bool:
i = find_first_mismatch(ids1, ids2)
ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
logger.error(" Expected: " + str(ids1))
logger.error(" Result: " + str(ids2))
logger.error(" Expected: " + str(ids1) + f" {[tokenizer1.decode([id]) for id in ids1]}")
logger.error(" Result: " + str(ids2) + f" {[tokenizer2.decode([id]) for id in ids2]}")

Reviewer comment: "Shadows built-in name 'id'"

compilade (Collaborator, Author), Jul 13, 2024: I've reverted these two lines in 59ce853, even though they were useful for debugging, because this is a test script anyway, and because these lines would also unnecessarily conflict with #8379.

encode_errors += 1
logger.error(f" {encode_errors=}")
if decode_errors < MAX_ERRORS and not check_detokenizer(text, text1, text2):