
token healing impl #29081

Closed · wants to merge 9 commits into from

Conversation

@ahmed-moubtahij (Contributor) commented Feb 18, 2024

What does this PR do?

Token healing rectifies the token boundary bias in greedy tokenization. It does this by trimming and regrowing the prompt to better align with the model's tokenizer, thus enhancing generation quality. The improvement is clearest with completion models.

Token boundary bias is a silent performance killer that doesn't seem to be very well known. It has a clear impact on completion quality.

A more thorough explanation of the problem: The Art of Prompt Design: Prompt Boundaries and Token Healing | by Scott Lundberg.

Motivation

Given a completion prompt with a partial URL ending with :, the model might have seen the expected completion :// as a single token in training. However, the prompt's tail token : tells it that the next token is not //, so it generates a wrong completion. Such errors compound in auto-regressive language models.
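
As a quick illustration (not part of the PR; just a sketch using GPT-2, and the exact token splits vary by tokenizer), the partial prompt ends on a different tail token than the one the model would have seen if the text had continued, which is what token healing trims and regrows:

  # Sketch: how a partial prompt changes the tail token (exact splits depend on the tokenizer).
  from transformers import AutoTokenizer

  tok = AutoTokenizer.from_pretrained("gpt2")

  print(tok.tokenize("The link is https://"))  # the trailing "://" typically ends up as one token
  print(tok.tokenize("The link is https:"))    # the trailing ":" is a different token, biasing
                                               # the model against generating "//" next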

Fixes #28346

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante (Member) commented Feb 19, 2024

CI is failing due to an automatic update in the pytest package; we are tracking it. Will let you know when it is sorted -- it will need a rebase

@ahmed-moubtahij (Contributor, Author) commented:

> CI is failing due to an automatic update in the pytest package, we are tracking it. Will let you know when it is sorted -- it will need a rebase

Thanks for the follow-up!

@gante (Member) commented Feb 19, 2024

@Ayenem main is fixed; rebasing should make CI green unless there are PR-specific issues :)

@ahmed-moubtahij (Contributor, Author) commented Feb 20, 2024

In case it's relevant, here are (some) listed remotes with git branch -r:

  origin/HEAD -> origin/main
  origin/heal_tokens
  origin/main
  origin/token_healing
  upstream/'delete-delete-doc'
  upstream/BritneyMuller-housekeeping-patch
  upstream/_dummy_fix_weight_only_usage
  upstream/_dummy_fix_weight_only_usage_2
  upstream/add-chat-glm
  upstream/add-deci-lm
  upstream/add-encode-special-tokens
  upstream/add-flash-decoding
  upstream/add-mamba
  upstream/add-prefix-space
  upstream/add-quantization-workflow

@gante (Member) commented Feb 26, 2024

(@Ayenem we're trying to fix the merge conflicts for you, and we're experimenting with a few GH permissions on our side. You may see a few test commits 🤗 )

@gante (Member) commented Feb 28, 2024

Now rebased after #29320 was merged, which addresses what was causing the last set of errors seen here. If everything went well, we should see a green CI here 🤞

@gante requested a review from ArthurZucker, February 28, 2024 12:54
@gante (Member) commented Feb 28, 2024

@Ayenem FYI, I've reverted the tokenizer input to your original suggestion (tokenizer passed to generate), after a discussion I had with @Rocketknight1. That way, the input is standardized and matches another incoming PR (#28932) 🤗
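
For illustration, usage would look roughly like the sketch below. The argument name for enabling the feature (assumed here to be token_healing) is an assumption and may differ from what the PR finally exposes; what is grounded in this discussion is that the tokenizer is passed directly to generate:

  # Sketch of the intended call pattern: the tokenizer is passed to generate() by the caller.
  from transformers import AutoModelForCausalLM, AutoTokenizer

  tokenizer = AutoTokenizer.from_pretrained("gpt2")
  model = AutoModelForCausalLM.from_pretrained("gpt2")

  inputs = tokenizer("The link is https:", return_tensors="pt")
  out = model.generate(**inputs, tokenizer=tokenizer, token_healing=True, max_new_tokens=8)
  print(tokenizer.decode(out[0], skip_special_tokens=True))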

@ahmed-moubtahij (Contributor, Author) commented:

> @Ayenem FYI, I've reverted the tokenizer input to your original suggestion (tokenizer passed to generate), after a discussion I had with @Rocketknight1. That way, the input is standardized and matches another incoming PR (#28932) 🤗

It does feel better to offload the tokenizer choice and loading to the caller. Thanks again for following up on this 🙏

@ahmed-moubtahij (Contributor, Author) commented:

CI is green! It was possible :')

@gante (Member) commented Mar 5, 2024

ping @ArthurZucker :)

@ArthurZucker (Collaborator) commented:

Sorry for the late review on it!

@ArthurZucker (Collaborator) left a review comment:

Left a few nits: mostly, safely import and protect the function, as the new dependency is (or should be) optional. Potentially use our own trie?

@@ -22,6 +22,7 @@
 import torch
 import torch.distributed as dist
+from pygtrie import CharTrie
@ArthurZucker (Collaborator) commented on this diff:
If this is an optional dependency we need to protect the import.
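
For reference, a module-level guard could look roughly like the sketch below; is_pygtrie_available is a hypothetical helper here, mirroring the existing is_*_available utilities, and this is not the PR's code:

  # Sketch: guard the optional import so transformers still imports without pygtrie installed.
  from ..utils import is_pygtrie_available  # hypothetical helper, analogous to is_torch_available

  if is_pygtrie_available():
      from pygtrie import CharTrie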

Comment on lines +1822 to +1823
"""
if tokenizer is None:
@ArthurZucker (Collaborator) commented:
Suggested change
-"""
-if tokenizer is None:
+"""
+requires_backends(self, ["pygtrie"])

we also need to make sure this function errors out correctly if used
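
A sketch of what that call-time check could look like, assuming pygtrie were registered in transformers' backend mapping so that requires_backends can raise a helpful ImportError (the registration itself is hypothetical, and this is not the PR's final code):

  # Sketch: fail fast with an install hint when heal_tokens is called without pygtrie installed.
  # requires_backends comes from transformers' import utilities.
  def heal_tokens(self, input_ids, tokenizer=None):
      requires_backends(self, ["pygtrie"])  # raises an ImportError explaining how to install pygtrie
      if tokenizer is None:
          raise ValueError(...)  # existing check: a tokenizer must be passed to `generate`
      ...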

"argument of `generate`."
)
bos_id, pad_id = tokenizer.bos_token_id, tokenizer.pad_token_id
vocab_trie = CharTrie(tokenizer.get_vocab())
@ArthurZucker (Collaborator) commented:
BTW we have https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils.py#L52

which could be used for this? It would remove the dependency (it might be additional work as well).
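
As a rough sketch of what dropping the dependency could involve: the CharTrie is only used for prefix lookups over the vocabulary, so one dependency-free (if less efficient) alternative is a linear scan; adapting the internal Trie from tokenization_utils.py would likely need a small prefix-lookup extension, which is the "additional work" mentioned above.

  # Dependency-free prefix lookup over the vocab (a sketch, not the PR's implementation):
  # returns the ids of all vocabulary tokens that extend a given prefix string.
  def extensions_of(vocab: dict, prefix: str) -> list:
      return [idx for tok, idx in vocab.items() if tok.startswith(prefix)]

  # e.g. extensions_of(tokenizer.get_vocab(), ":") would include the id of "://" if it is in the vocab.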

input_ids = torch.where(input_ids == bos_id, pad_id, input_ids)

tail_ids = input_ids[:, -1].tolist()
space_tok = tokenizer.tokenize(" ")[0]
@ArthurZucker (Collaborator) commented:
Not 100% sure this will always do what you want; specifically, for tokenizers that add a prefix space you could get [▁▁]
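
A small illustration of that caveat (a sketch; the exact behavior depends on the checkpoint, and meta-llama/Llama-2-7b-hf is gated, but any SentencePiece tokenizer that prepends a prefix space behaves similarly):

  # Sketch: SentencePiece tokenizers with a prefix space may not return a lone "▁" here.
  from transformers import AutoTokenizer

  sp_tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
  print(sp_tok.tokenize(" "))  # may print ['▁▁'] instead of ['▁'], so tokenize(" ")[0]
                               # is not necessarily the single-space token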

@LeonardoEmili (Contributor) commented:
Hi @Ayenem, thanks for this feature. I was curious to try it early and see how it works on my domain data, but I ran into issues during generation with the example data provided (stack trace attached below). Could you share an example script showing how to test it?

Traceback (most recent call last):
  File "token_heal.py", line 33, in <module>
    output = model.generate(
  File "/home/leonardo/.local/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/leonardo/projects/transformers/src/transformers/generation/utils.py", line 1439, in generate
    input_ids = self.heal_tokens(input_ids, tokenizer)
  File "/home/leonardo/projects/transformers/src/transformers/generation/utils.py", line 1861, in heal_tokens
    seq_bias[(tail_id,)] += 1.0
KeyError: (518,)

Environment used:

  • transformers: I'm checked out at the head of your fork, this specific commit
  • pygtrie version: 2.5.0
  • Python version: 3.8.10
  • Model: AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
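
For what it's worth, the KeyError in the trace above suggests the tail token id is not always among the keys of seq_bias; one possible guard (a sketch, not necessarily the PR's eventual fix) is to only boost the tail token when it is present:

  # Sketch: avoid the KeyError by only boosting the tail token if it is a key of seq_bias.
  if (tail_id,) in seq_bias:
      seq_bias[(tail_id,)] += 1.0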

@ahmed-moubtahij deleted the heal_tokens branch March 28, 2024 21:24
@ahmed-moubtahij mentioned this pull request Apr 6, 2024
Successfully merging this pull request may close these issues:

  Token healing (under 40 LOC)