From 68220feaf8b8c8f8b9d1690fd975bba445d5d4b6 Mon Sep 17 00:00:00 2001 From: jaime-m-p <> Date: Tue, 25 Jun 2024 17:36:44 +0200 Subject: [PATCH] Update bruteforce test --- tests/test-tokenizer-random.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/test-tokenizer-random.py b/tests/test-tokenizer-random.py index 5d4f6e8765806..d0ed3558edd13 100644 --- a/tests/test-tokenizer-random.py +++ b/tests/test-tokenizer-random.py @@ -235,6 +235,8 @@ def generator_custom_text_edge_cases() -> Iterator[str]: 'å', # mpt '\U000ac517', # utf-8 encode error, falcon '\U000522f4', # utf-8 encode error, starcoder + "abcd", + " abcd", ] @@ -334,7 +336,7 @@ def _valid(cpt): return False # if cpt == 0x2029: # deepseek-llm # return False - if unicodedata.category(chr(cpt)) in ( "Cn", "Cs", "Co" ): # undefined, surrogates, private + if unicodedata.category(chr(cpt)) in ("Cn", "Cs", "Co"): # undefined, surrogates, private return False return True @@ -426,6 +428,7 @@ def check_detokenizer(text: str, text1: str, text2: str) -> bool: t_start = time.perf_counter() encode_errors = 0 decode_errors = 0 + MAX_ERRORS = 10 logger.info("%s: %s" % (generator.__name__, "ini")) for text in generator: @@ -444,7 +447,7 @@ def check_detokenizer(text: str, text1: str, text2: str) -> bool: t_encode2 += t2 - t1 t_decode1 += t3 - t2 t_decode2 += t4 - t3 - if ids1 != ids2: + if encode_errors < MAX_ERRORS and ids1 != ids2: i = find_first_mismatch(ids1, ids2) ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1] ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1] @@ -452,7 +455,7 @@ def check_detokenizer(text: str, text1: str, text2: str) -> bool: logger.error(" Result: " + str(ids2)) encode_errors += 1 logger.error(f" {encode_errors=}") - if not check_detokenizer(text, text1, text2): + if decode_errors < MAX_ERRORS and not check_detokenizer(text, text1, text2): i = find_first_mismatch(text1, text2) text1 = list(text1[max(0, i - 2) : i + 5 + 1]) text2 = list(text2[max(0, i - 2) : i + 5 + 1]) @@ -460,7 +463,7 @@ def check_detokenizer(text: str, text1: str, text2: str) -> bool: logger.error(" Result: " + " ".join(hex(ord(x)) for x in text2)) decode_errors += 1 logger.error(f" {decode_errors=}") - if encode_errors >= 10 or decode_errors >= 10: + if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS: logger.error(f" EXIT: {encode_errors=} {decode_errors=}") # raise Exception() break