-
Notifications
You must be signed in to change notification settings - Fork 10.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Tokenizer fixes #8379
Draft
jaime-m-p
wants to merge
9
commits into
ggerganov:master
Choose a base branch
from
jaime-m-p:tokenizer-fixes
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Tokenizer fixes #8379
Changes from 5 commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
e8b3955
Fix pyparse problems: gcc inline functions
9307c3f
Test l/r-strip for more than 4 spaces
a943b42
Improve mismatch range localization
dec64ef
Compare vocabs
c184db7
Options to mange token text decoding errors:
3eb1900
Skip literal UNUSED token checks
c4956e4
update test: fix special and added token lists
9b8e05b
Merge commit 'f4444d99' into tokenizer-fixes
3db5058
Merge branch 'master' into tokenizer-fixes
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,7 +36,7 @@ def __init__(self, path_llama_h: str = None, path_includes: list[str] = [], path | |
self.lib.llama_backend_init() | ||
|
||
def _load_libllama_cffi(self, path_llama_h: str, path_includes: list[str], path_libllama: str): | ||
cmd = ["gcc", "-E", "-P", "-D__restrict=", "-D__attribute__(x)=", "-D__asm__(x)="] | ||
cmd = ["gcc", "-O0", "-fno-inline", "-E", "-P", "-D__restrict=", "-D__attribute__(x)=", "-D__asm__(x)="] | ||
cmd += ["-I" + path for path in path_includes] + [path_llama_h] | ||
res = subprocess.run(cmd, stdout=subprocess.PIPE) | ||
assert (res.returncode == 0) | ||
|
@@ -112,9 +112,25 @@ def detokenize(self, ids: list[int], remove_special: bool = False, unparse_speci | |
num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), remove_special, unparse_special) | ||
return str(self.ffi.buffer(self.text_buff, num), encoding="utf-8", errors="replace") # replace errors with '\uFFFD' | ||
|
||
def get_vocab(self, detokenize=False) -> list[str]: | ||
vocab: list[str] = [] | ||
num_tokens = self.lib.llama_n_vocab(self.model) | ||
for id in range(num_tokens): | ||
if detokenize: | ||
text = self.detokenize([id], remove_special=False, unparse_special=True) | ||
else: | ||
text = self.lib.llama_token_get_text(self.model, id) | ||
text = self.ffi.string(text) | ||
text = str(text, encoding="utf-8", errors="replace") # replace errors with '\uFFFD' | ||
vocab.append(text) | ||
return vocab | ||
|
||
|
||
class Tokenizer: | ||
|
||
def get_vocab(self, detokenize=False) -> list[str]: | ||
raise NotImplementedError | ||
|
||
def encode(self, text: str) -> list[int]: | ||
raise NotImplementedError | ||
|
||
|
@@ -125,7 +141,7 @@ def decode(self, ids: list[int]) -> str: | |
class TokenizerGroundtruth (Tokenizer): | ||
|
||
def __init__(self, dir_tokenizer: str): | ||
self.model = AutoTokenizer.from_pretrained(dir_tokenizer) | ||
self.model = AutoTokenizer.from_pretrained(dir_tokenizer, trust_remote_code=False) | ||
# guess BOS and EOS | ||
ids = self.encode("a") | ||
assert 1 <= len(ids) <= 3 | ||
|
@@ -134,15 +150,24 @@ def __init__(self, dir_tokenizer: str): | |
self.add_bos_token = getattr(self.model, "add_bos_token", add_bos_token) | ||
self.add_eos_token = getattr(self.model, "add_eos_token", add_eos_token) | ||
# build vocab | ||
tokens = list(self.model.get_vocab().values()) | ||
self.vocab = self.model.batch_decode(tokens, skip_special_tokens=True) | ||
self.vocab = list(sorted(self.vocab)) | ||
self.vocab = self.get_vocab(detokenize=True) | ||
# tokens and lists | ||
self.special_tokens = list(self.model.all_special_tokens) | ||
self.added_tokens = list(self.model.added_tokens_encoder) | ||
self.bos_token = self.model.bos_token | ||
self.eos_token = self.model.eos_token | ||
|
||
def get_vocab(self, detokenize=False) -> list[str]: | ||
max_token_id = max(self.model.get_vocab().values()) | ||
if detokenize: | ||
ids = list(range(max_token_id + 1)) | ||
vocab = self.model.batch_decode(ids, skip_special_tokens=False) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you think this should be used in the convert script(s) instead of directly getting the strings from EDIT: this might be a bad idea, since the tokenizer merges won't directly match with the strings from the vocab if that's done |
||
else: | ||
vocab = [None] * (max_token_id + 1) | ||
for text, id in self.model.get_vocab().items(): | ||
vocab[id] = text | ||
return vocab | ||
|
||
def encode(self, text: str) -> list[int]: | ||
return self.model.encode(text, add_special_tokens=True) | ||
|
||
|
@@ -159,6 +184,9 @@ def __init__(self, vocab_file: str): | |
self.libllama = LibLlama() | ||
self.model = LibLlamaModel(self.libllama, vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096)) | ||
|
||
def get_vocab(self, detokenize=False) -> list[str]: | ||
return self.model.get_vocab(detokenize) | ||
|
||
def encode(self, text: str) -> list[int]: | ||
return self.model.tokenize(text, add_special=True, parse_special=True) | ||
|
||
|
@@ -273,7 +301,7 @@ def generator_apostrophe() -> Iterator[str]: | |
|
||
|
||
def generator_added_lr_strip(tokenizer: TokenizerGroundtruth) -> Iterator[str]: | ||
WHITESPACES = ["", " ", " ", "\n", "\r\n", "\n\n", "\t", "\t\t"] | ||
WHITESPACES = ["", " ", " ", "\n", "\r\n", "\n\n", "\t", "\t\t", " "] | ||
all_tokens = list(sorted(set(tokenizer.special_tokens + tokenizer.added_tokens))) | ||
for token in all_tokens: | ||
for lstrip in WHITESPACES: | ||
|
@@ -404,14 +432,6 @@ def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100 | |
|
||
def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]): | ||
|
||
def find_first_mismatch(ids1: list[int], ids2: list[int]): | ||
for i, (a, b) in enumerate(zip(ids1, ids2)): | ||
if a != b: | ||
return i | ||
if len(ids1) == len(ids2): | ||
return -1 | ||
return min(len(ids1), len(ids2)) | ||
|
||
def check_detokenizer(text: str, text1: str, text2: str) -> bool: | ||
if text1 == text2: # equal to TokenizerGroundtruth? | ||
return True | ||
|
@@ -431,6 +451,7 @@ def check_detokenizer(text: str, text1: str, text2: str) -> bool: | |
t_start = time.perf_counter() | ||
encode_errors = 0 | ||
decode_errors = 0 | ||
total_tests = 0 | ||
MAX_ERRORS = 10 | ||
|
||
logger.info("%s: %s" % (generator.__name__, "ini")) | ||
|
@@ -450,21 +471,44 @@ def check_detokenizer(text: str, text1: str, text2: str) -> bool: | |
t_encode2 += t2 - t1 | ||
t_decode1 += t3 - t2 | ||
t_decode2 += t4 - t3 | ||
if encode_errors < MAX_ERRORS and ids1 != ids2: | ||
i = find_first_mismatch(ids1, ids2) | ||
ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1] | ||
ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1] | ||
# compare | ||
encode_ok = ids1 == ids2 | ||
decode_ok = check_detokenizer(text, text1, text2) | ||
encode_errors += not encode_ok | ||
decode_errors += not decode_ok | ||
total_tests += 1 | ||
if (encode_errors < MAX_ERRORS and not encode_ok) or (decode_errors < MAX_ERRORS and not decode_ok): | ||
def _compare(text: str): | ||
ids1 = tokenizer1.encode(text) | ||
ids2 = tokenizer2.encode(text) | ||
text1 = tokenizer1.decode(ids1) | ||
text2 = tokenizer2.decode(ids1) | ||
encode_ok = ids1 == ids2 | ||
decode_ok = check_detokenizer(text, text1, text2) | ||
ok = encode_ok and decode_ok | ||
return ok, ids1, ids2, text1, text2 | ||
a, b = 0, len(text) | ||
for step in [64, 32, 16, 8, 4, 2, 1]: | ||
while a < b: | ||
t = max(a, b - step) | ||
if _compare(text[a : t])[0]: | ||
break | ||
b = t | ||
for step in [64, 32, 16, 8, 4, 2, 1]: | ||
while a < b: | ||
t = min(a + step, b) | ||
if _compare(text[t : b])[0]: | ||
break | ||
a = t | ||
ok, ids1, ids2, text1, text2 = _compare(text[a : b]) | ||
assert a <= b and not ok | ||
logger.error(" Text:" + repr(text[a : b])) | ||
logger.error(" " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text[a : b])) | ||
logger.error(" Expected: " + str(ids1)) | ||
logger.error(" Result: " + str(ids2)) | ||
encode_errors += 1 | ||
logger.error(" Expected: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text1)) | ||
logger.error(" Result: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text2)) | ||
logger.error(f" {encode_errors=}") | ||
if decode_errors < MAX_ERRORS and not check_detokenizer(text, text1, text2): | ||
i = find_first_mismatch(text1, text2) | ||
text1 = list(text1[max(0, i - 2) : i + 5 + 1]) | ||
text2 = list(text2[max(0, i - 2) : i + 5 + 1]) | ||
logger.error(" Expected: " + " ".join(hex(ord(x)) for x in text1)) | ||
logger.error(" Result: " + " ".join(hex(ord(x)) for x in text2)) | ||
decode_errors += 1 | ||
logger.error(f" {decode_errors=}") | ||
if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS: | ||
logger.error(f" EXIT: {encode_errors=} {decode_errors=}") | ||
|
@@ -475,6 +519,34 @@ def check_detokenizer(text: str, text1: str, text2: str) -> bool: | |
logger.info(f"{generator.__name__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}") | ||
|
||
|
||
def compare_vocabs(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp): | ||
|
||
MAX_PRINT_ERRORS = 10 | ||
|
||
logger.info("compare_vocabs: ini") | ||
|
||
t_start = time.perf_counter() | ||
|
||
for detokenize in (False, True): | ||
vocab1 = tokenizer1.get_vocab(detokenize) | ||
vocab2 = tokenizer2.get_vocab(detokenize) | ||
if vocab1 != vocab2: | ||
num_errors = 0 | ||
for i in range(max(len(vocab1), len(vocab2))): | ||
text1 = vocab1[i] if i < len(vocab1) else "" | ||
text2 = vocab2[i] if i < len(vocab2) else "" | ||
is_unused = text1.startswith("[UNUSED_TOKEN_") # AutoTokenizer adds more unused tokens than SentencePiece ? | ||
if text1 != text2 and is_unused and text2: | ||
num_errors += 1 | ||
if num_errors < MAX_PRINT_ERRORS: | ||
logger.error(f" {detokenize=} id={i} expected={repr(text1)} result={repr(text2)}") | ||
if num_errors: | ||
logger.error(f" {num_errors=}") | ||
|
||
t_total = time.perf_counter() - t_start | ||
logger.info(f"compare_vocabs: end, {t_total=:.3f}") | ||
|
||
|
||
def main(argv: list[str] = None): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("vocab_file", help="path to vocab 'gguf' file") | ||
|
@@ -488,13 +560,16 @@ def main(argv: list[str] = None): | |
tokenizer1 = TokenizerGroundtruth(args.dir_tokenizer) | ||
tokenizer2 = TokenizerLlamaCpp(args.vocab_file) | ||
|
||
# compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text()) | ||
# compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text_edge_cases()) | ||
compare_tokenizers(tokenizer1, tokenizer2, generator_ascii_lr_strip()) | ||
compare_tokenizers(tokenizer1, tokenizer2, generator_apostrophe()) | ||
compare_tokenizers(tokenizer1, tokenizer2, generator_unicodes()) | ||
compare_tokenizers(tokenizer1, tokenizer2, generator_vocab_words(tokenizer1)) | ||
compare_tokenizers(tokenizer1, tokenizer2, generator_added_lr_strip(tokenizer1)) | ||
compare_vocabs(tokenizer1, tokenizer2) | ||
|
||
compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text()) | ||
compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text_edge_cases()) | ||
# compare_tokenizers(tokenizer1, tokenizer2, generator_representative(tokenizer1)) | ||
# compare_tokenizers(tokenizer1, tokenizer2, generator_ascii_lr_strip()) | ||
# compare_tokenizers(tokenizer1, tokenizer2, generator_apostrophe()) | ||
# compare_tokenizers(tokenizer1, tokenizer2, generator_unicodes()) | ||
# compare_tokenizers(tokenizer1, tokenizer2, generator_vocab_words(tokenizer1)) | ||
# compare_tokenizers(tokenizer1, tokenizer2, generator_added_lr_strip(tokenizer1)) | ||
# compare_tokenizers(tokenizer1, tokenizer2, generator_random_added_tokens(tokenizer1, 10_000)) | ||
# compare_tokenizers(tokenizer1, tokenizer2, generator_random_chars(10_000)) | ||
# compare_tokenizers(tokenizer1, tokenizer2, generator_random_unicodes(10_000)) | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think
-fno-inline
is redundant with-O0
. And-O0
alone works, while-fno-inline
alone doesn't.Anyway, I suggest resolving the conflict with
master
.