From a2a7f71604ebc930ed6cbc594eb0594811e24bc8 Mon Sep 17 00:00:00 2001
From: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Date: Thu, 28 Mar 2024 21:58:40 +0900
Subject: [PATCH] =?UTF-8?q?[=20`TokenizationLlama`]=20fix=20the=20way=20we?=
 =?UTF-8?q?=20convert=20tokens=20to=20strings=20to=20keep=20leading=20spac?=
 =?UTF-8?q?es=20=F0=9F=9A=A8=20breaking=20fix=20(#29453)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* nit

* update test and fix test

* fixup
---
 src/transformers/models/llama/tokenization_llama.py |  2 ++
 tests/models/llama/test_tokenization_llama.py       | 13 +++++++++++++
 2 files changed, 15 insertions(+)

diff --git a/src/transformers/models/llama/tokenization_llama.py b/src/transformers/models/llama/tokenization_llama.py
index 2f68d6c1e951..744e2e3fe2c2 100644
--- a/src/transformers/models/llama/tokenization_llama.py
+++ b/src/transformers/models/llama/tokenization_llama.py
@@ -295,6 +295,8 @@ def convert_tokens_to_string(self, tokens):
                 prev_is_special = True
                 current_sub_tokens = []
             else:
+                if prev_is_special and i == 1 and self.add_prefix_space and not token.startswith(SPIECE_UNDERLINE):
+                    out_string += " "
                 current_sub_tokens.append(token)
                 prev_is_special = False
         out_string += self.sp_model.decode(current_sub_tokens)
diff --git a/tests/models/llama/test_tokenization_llama.py b/tests/models/llama/test_tokenization_llama.py
index 0cee3347c408..5a0bcea48af1 100644
--- a/tests/models/llama/test_tokenization_llama.py
+++ b/tests/models/llama/test_tokenization_llama.py
@@ -581,6 +581,19 @@ def test_special_token_special_word(self):
         decoded_tokens = tokenizer.decode(input_ids)
         self.assertEqual(decoded_tokens, " Hello how")
 
+        # Let's make sure the space is preserved
+        input_ids = tokenizer.encode("hello", add_special_tokens=True)
+        self.assertEqual(input_ids, [1, 22172])
+        tokens = tokenizer.tokenize("hello")
+        self.assertEqual(tokens, ["▁hello"])
+        decoded_tokens = tokenizer.decode(input_ids)
+        self.assertEqual(decoded_tokens, " hello")
+
+        input_ids = tokenizer.encode("hello", add_special_tokens=False)
+        self.assertEqual(input_ids, [22172])
+        decoded_tokens = tokenizer.decode(input_ids)
+        self.assertEqual(decoded_tokens, "hello")
+
     def test_some_edge_cases(self):
         tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False)