
Detokenization discrepancy with Llama3.1 #35175

Closed

AbrahamSanders opened this issue Dec 9, 2024 · 5 comments
AbrahamSanders commented Dec 9, 2024

System Info

  • transformers version: 4.47.0
  • Platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.35
  • Python version: 3.12.7
  • Huggingface_hub version: 0.26.5
  • Safetensors version: 0.4.5
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): not installed (NA)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: N/A

Who can help?

@ArthurZucker @itazap

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Spaces are stripped from the space-prefixed token Ġ' when it is followed by a common contraction suffix (e.g., n't, 'm, 's, 've), even when it is not appropriate to do so. This happens because clean_up_tokenization_spaces defaults to True for the Llama 3.1 tokenizer.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

original = " plunged the long 'sword' into"
input_ids = tokenizer.encode(original, add_special_tokens=False)
tokens = tokenizer.convert_ids_to_tokens(input_ids)
decoded = tokenizer.decode(input_ids)  # uses the config default (clean_up_tokenization_spaces=True)
decoded2 = tokenizer.decode(input_ids, clean_up_tokenization_spaces=False)

print("token ids:                ", input_ids)
print("tokens:                   ", tokens)
print("original:                ", original)
print("decoded (default):       ", decoded)
print("decoded (clean_up=False):", decoded2)

Produces

token ids:                 [75803, 279, 1317, 364, 80138, 6, 1139]
tokens:                    ['Ġplunged', 'Ġthe', 'Ġlong', "Ġ'", 'sword', "'", 'Ġinto']
original:                  plunged the long 'sword' into
decoded (default):         plunged the long'sword' into
decoded (clean_up=False):  plunged the long 'sword' into

Expected behavior

I would expect the original string to match the decoded string in all cases unless it actually contains "traditional" tokenization spacing (e.g., it 's vs. it's). Perhaps a good approach would be to modify the clean_up_tokenization function so it only applies this rule when the contraction suffix is immediately followed by another space (see the sketch below).
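
A minimal sketch of that idea, assuming clean_up_tokenization currently does plain string replacements for these suffixes; the function name and regex here are hypothetical, not the library's API:

import re

def clean_up_tokenization_suffix_aware(out_string: str) -> str:
    # Hypothetical variant: collapse " 's", " n't", etc. only when the
    # suffix ends a word (followed by whitespace or end of string), so a
    # string like " 'sword'" is left untouched.
    return re.sub(r" ('s|'m|'ve|'re|'ll|'d|n't)(?=\s|$)", r"\1", out_string)

print(clean_up_tokenization_suffix_aware("it 's fine"))                     # it's fine
print(clean_up_tokenization_suffix_aware(" plunged the long 'sword' into")) # unchanged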

@ArthurZucker
Collaborator

As you mentioned, clean_up_tokenization_spaces is set to True. It should be set to False!
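
For anyone hitting this, a minimal sketch of that override at load time, assuming from_pretrained forwards extra kwargs to the tokenizer's init (where clean_up_tokenization_spaces is then used as the decode default):

from transformers import AutoTokenizer

# Override the config value once at load time; decode() will then leave
# spacing untouched unless clean_up_tokenization_spaces=True is passed.
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    clean_up_tokenization_spaces=False,
)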

@denadai2

denadai2 commented Dec 10, 2024

Is this something to be fixed in master, or should we set it manually every time?

@AbrahamSanders
Author

@ArthurZucker yes, same question as @denadai2.

It seems counterintuitive for the tokenizer's default to be True, since most applications won't be working with pre-tokenized text. I think a good default would be False in the tokenizer config; anyone working with pre-tokenized text could then set it to True manually in their decode() call, as in the sketch below. Alternatively, clean_up_tokenization could be modified as I suggested above. Any thoughts?
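
To illustrate that opt-in path, a short sketch (assuming the per-call flag overrides whatever the config default is; the example string is illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

# Pre-tokenized text with "traditional" spacing can opt in per call:
ids = tokenizer.encode("it 's", add_special_tokens=False)
print(tokenizer.decode(ids, clean_up_tokenization_spaces=True))   # it's
print(tokenizer.decode(ids, clean_up_tokenization_spaces=False))  # it 's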

@ArthurZucker
Collaborator

It IS counterintuitive, but we can't easily break things in transformers; this has been True for models that are already online, for example.
#31938 made it False by default when unset, so Llama would not have this problem if the value were simply absent. But here the model has "clean_up_tokenization_spaces": true, in the config. Let's open a PR to change this!
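
Until that lands, a sketch of the equivalent local fix, assuming the clean_up_tokenization_spaces attribute is persisted by save_pretrained (the output path is illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

# Flip the attribute and save a local copy; the saved tokenizer_config.json
# should then carry "clean_up_tokenization_spaces": false.
tokenizer.clean_up_tokenization_spaces = False
tokenizer.save_pretrained("./llama-3.1-8b-instruct-tokenizer-fixed")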


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
