[ TokenizationLlama] fix the way we convert tokens to strings to keep leading spaces 🚨 breaking fix (#29453)

* nit

* update test and fix test

* fixup
ArthurZucker authored Mar 28, 2024
1 parent e677479 commit a2a7f71
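
The gist of the fix, as exercised by the new test below: round-tripping text through encode and decode should no longer drop the space between a leading special token and the first word. A minimal sketch of the expected behavior, assuming a transformers version that includes this commit and access to the huggyllama/llama-7b checkpoint:

from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")

# "hello" tokenizes to the single piece "▁hello"; with add_special_tokens=True
# the BOS id (1) is prepended, giving [1, 22172].
input_ids = tokenizer.encode("hello", add_special_tokens=True)

# With this fix the space survives decoding: "<s> hello", not "<s>hello".
print(tokenizer.decode(input_ids))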
Showing 2 changed files with 15 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/transformers/models/llama/tokenization_llama.py
@@ -295,6 +295,8 @@ def convert_tokens_to_string(self, tokens):
                 prev_is_special = True
                 current_sub_tokens = []
             else:
+                if prev_is_special and i == 1 and self.add_prefix_space and not token.startswith(SPIECE_UNDERLINE):
+                    out_string += " "
                 current_sub_tokens.append(token)
                 prev_is_special = False
         out_string += self.sp_model.decode(current_sub_tokens)
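The new branch fires in one narrow situation: the previous token was a special token, we are at index 1 (so the special token opened the sequence), the tokenizer was configured with add_prefix_space, and the current piece does not already carry the SentencePiece space marker (▁). In that case the space between the special token and the first word is re-inserted rather than lost. A minimal standalone sketch of just this predicate, with the helper name needs_leading_space as an illustrative assumption, not part of the actual patch:

SPIECE_UNDERLINE = "▁"  # the SentencePiece whitespace marker

def needs_leading_space(prev_is_special: bool, i: int, add_prefix_space: bool, token: str) -> bool:
    # Mirrors the condition added in convert_tokens_to_string: re-insert the
    # space between a leading special token and the first plain word piece.
    return (
        prev_is_special
        and i == 1
        and add_prefix_space
        and not token.startswith(SPIECE_UNDERLINE)
    )

assert needs_leading_space(True, 1, True, "hello")        # "<s>" then "hello": add " "
assert not needs_leading_space(True, 1, True, "▁hello")   # piece already encodes the space
assert not needs_leading_space(False, 1, True, "hello")   # no leading special token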
13 changes: 13 additions & 0 deletions tests/models/llama/test_tokenization_llama.py
@@ -581,6 +581,19 @@ def test_special_token_special_word(self):
         decoded_tokens = tokenizer.decode(input_ids)
         self.assertEqual(decoded_tokens, " <s> Hello<s> how")
 
+        # Let's make sure the space is preserved
+        input_ids = tokenizer.encode("hello", add_special_tokens=True)
+        self.assertEqual(input_ids, [1, 22172])
+        tokens = tokenizer.tokenize("hello")
+        self.assertEqual(tokens, ["▁hello"])
+        decoded_tokens = tokenizer.decode(input_ids)
+        self.assertEqual(decoded_tokens, "<s> hello")
+
+        input_ids = tokenizer.encode("hello", add_special_tokens=False)
+        self.assertEqual(input_ids, [22172])
+        decoded_tokens = tokenizer.decode(input_ids)
+        self.assertEqual(decoded_tokens, "hello")
+
     def test_some_edge_cases(self):
         tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False)

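The new assertions can be run against a local checkout with a standard pytest invocation (not prescribed by the commit itself), for example: python -m pytest tests/models/llama/test_tokenization_llama.py -k test_special_token_special_word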
