token healing impl #29081
@@ -7,4 +7,5 @@ seqeval
tensorboard
evaluate >= 0.2.0
torch
accelerate
pygtrie >= 2.5.0
@@ -0,0 +1,44 @@
<!-- back to top link -->
<a name="readme-top"></a>

<!-- ABOUT THE PROJECT -->
## What is token healing?

Token healing rectifies the token boundary bias of greedy tokenization. It does this by trimming the tail of the prompt and regrowing it so that the prompt better aligns with the model's tokenizer, which improves generation quality. The improvement is clearest with completion models.

Example: given a completion prompt with a partial URL ending in `:`, the model may have seen the expected completion `://` as a _single_ token during training. However, the prompt's tail token `:` signals that the next token is unlikely to be `//` (otherwise the tokenizer would have emitted `://` instead), so the model searches among the wrong completions. Such errors compound in auto-regressive language models.

Debiasing token boundaries also reduces output sensitivity to prompts that end with whitespace.

A more thorough explanation can be found in [The Art of Prompt Design: Prompt Boundaries and Token Healing | by Scott Lundberg](https://towardsdatascience.com/the-art-of-prompt-design-prompt-boundaries-and-token-healing-3b2448b0be38).
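To see the bias concretely, the snippet below is a minimal sketch: it assumes a GPT-2-style byte-level BPE vocabulary (chosen only for illustration), and the exact pieces depend on the tokenizer you load.

```py
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative choice of tokenizer

# With the full URL prefix, '://' is typically a single vocabulary token...
print(tokenizer.tokenize('The link is <a href="http://'))

# ...while the truncated prompt ends with a bare ':' piece, which (as seen in training)
# rarely precedes '//'. This is the boundary bias that healing removes.
print(tokenizer.tokenize('The link is <a href="http:'))
```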
## Installation

`pip install transformers pygtrie`

## Usage

```py
prompt = 'The link is <a href="http:'
raw_output = generate(prompt, completion_model, tokenizer, token_healing=False)
# The link is <a href="http://www/dailymail&#

# The model saw '://' as a single token in training. Seeing a prompt ending with `:` tells it that the
# next token is likely not `//`, because otherwise it would've seen `://`.
# Thus, it completes with a token other than `//`, in this case, `&`.

healed_output = generate(prompt, completion_model, tokenizer, token_healing=True)
# The link is <a href="http://www.365doki.com/post/3699

# You can also use token healing in isolation.
# This can be useful if you have other work to do before the generation,
# or if you want to delegate generation to another process.
input_ids = tokenizer(test_prompts, return_tensors='pt', padding=True).input_ids.cuda()
healed_ids = model.heal_tokens(input_ids)
healed_prompts = tokenizer.batch_decode(healed_ids, skip_special_tokens=True)
# outputs the healed prompts without further completion/generation
```

See `run_token_healing.py` for the full example.

<p align="right">(<a href="#readme-top">back to top</a>)</p>
@@ -0,0 +1,2 @@
pygtrie >= 2.5.0
transformers >= 4.36.2
@@ -0,0 +1,48 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


def generate(inputs, model, tokenizer, token_healing):
    input_ids = tokenizer(inputs, return_tensors="pt", padding=True).input_ids.cuda()
    generation_config = GenerationConfig(
        max_new_tokens=8,
        token_healing=token_healing,
        pad_token_id=model.config.pad_token_id,
        repetition_penalty=1.1,
    )
    output = model.generate(inputs=input_ids, generation_config=generation_config)
    return tokenizer.batch_decode(output, skip_special_tokens=True)


model_name_or_path = "TheBloke/deepseek-llm-7B-base-GPTQ"
completion_model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    device_map="auto",
    trust_remote_code=False,
    revision="main",
    use_cache=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

test_prompts = [
    'An example ["like this"] and another example [',
    'The link is <a href="http:',
    'The link is <a href="http',  # test aggressive healing http->https
    "I read a book about ",  # test trailing whitespace
    "I read a book about",  # test nothing to heal
]

raw_output = generate(test_prompts, completion_model, tokenizer, token_healing=False)
healed_output = generate(test_prompts, completion_model, tokenizer, token_healing=True)

for p, a, b in zip(test_prompts, raw_output, healed_output):
    print(f"\nPrompt: {p}\nWithout healing:\n{a}\nWith healing:\n{b}")

# You can also use token healing in isolation.
# This can be useful if you have other work to do before the generation,
# or if you want to delegate generation to another process.
input_ids = tokenizer(test_prompts, return_tensors="pt", padding=True).input_ids.cuda()
healed_ids = completion_model.heal_tokens(input_ids)
healed_prompts = tokenizer.batch_decode(healed_ids, skip_special_tokens=True)
print("\nhealed prompts:")
for p in healed_prompts:
    print(p)
@@ -22,6 +22,7 @@
import torch
import torch.distributed as dist
from pygtrie import CharTrie
from torch import nn

from ..cache_utils import Cache, DynamicCache, StaticCache
@@ -85,6 +86,7 @@
if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel
    from ..tokenization_utils_base import PreTrainedTokenizerBase
    from .streamers import BaseStreamer

logger = logging.get_logger(__name__)
@@ -1323,6 +1325,7 @@ def generate(
            synced_gpus = True
        else:
            synced_gpus = False
        tokenizer = kwargs.pop("tokenizer", None)

        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
        self._validate_model_class()
@@ -1432,6 +1435,9 @@ def generate(
        else:
            input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

        if generation_config.token_healing:
            input_ids = self.heal_tokens(input_ids, tokenizer)

        if streamer is not None:
            streamer.put(input_ids.cpu())
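For context, this is roughly how the new hook is exercised from user code, as a sketch of the call pattern this PR enables rather than an official API example. The model name is a placeholder; the model's tokenizer is passed to `generate` so that `heal_tokens` can decode and re-tokenize the prompt (it is picked up by the `kwargs.pop("tokenizer", None)` line above).

```py
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_name = "gpt2"  # placeholder; any causal LM checkpoint
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token  # heal_tokens uses the pad token id

inputs = tokenizer('The link is <a href="http:', return_tensors="pt")
config = GenerationConfig(max_new_tokens=8, token_healing=True, pad_token_id=tokenizer.pad_token_id)

# The tokenizer kwarg is forwarded to heal_tokens before decoding starts.
output = model.generate(inputs.input_ids, generation_config=config, tokenizer=tokenizer)
print(tokenizer.batch_decode(output, skip_special_tokens=True))
```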
@@ -1803,6 +1809,67 @@ def typeerror():

        return result

    def heal_tokens(
        self, input_ids: torch.LongTensor, tokenizer: Optional["PreTrainedTokenizerBase"] = None
    ) -> torch.LongTensor:
        r"""
        Heals the tail token of each sequence in `input_ids` so that the prompt better matches the tokenizer's
        vocabulary (token healing).
        Parameters:
            input_ids (`torch.LongTensor`): The sequence used as a prompt for the generation.
            tokenizer (`PreTrainedTokenizerBase`, *optional*): The tokenizer used to decode the input ids.
        Return:
            `torch.LongTensor` where each sequence has its tail token replaced with its appropriate extension.
        """
        if tokenizer is None:
            raise ValueError(
                "When generating with token healing, you must pass the model's tokenizer to the `tokenizer` "
                "argument of `generate`."
            )

> **Review comment** (on lines +1826 to +1827, with a suggested change): we also need to make sure this function errors out correctly if used
        bos_id, pad_id = tokenizer.bos_token_id, tokenizer.pad_token_id
        vocab_trie = CharTrie(tokenizer.get_vocab())

> **Review comment**: BTW we have https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils.py#L52 which could be used for this? Would remove the dependency? (It might be additional work as well.)
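For readers unfamiliar with `pygtrie`, the character trie is only used for prefix lookups over the vocabulary. A small self-contained sketch follows; the toy vocabulary and token ids are made up for illustration.

```py
from pygtrie import CharTrie

# Toy vocabulary mapping token strings to made-up token ids.
vocab = {":": 25, "://": 1358, ":=": 2301, "//": 421}
vocab_trie = CharTrie(vocab)

# All vocabulary tokens that start with the tail token ":" (including ":" itself)
# are the candidate extensions used for healing below.
print(sorted(vocab_trie.values(prefix=":")))  # [25, 1358, 2301]
```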
        gen_cfg = GenerationConfig(max_new_tokens=1, pad_token_id=pad_id)

        # assumption: leading/trailing whitespace is not meaningful, so the prompts are
        # stripped before re-tokenizing to desensitize generation to whitespace artefacts
        prompts = [p.strip() for p in tokenizer.batch_decode(input_ids, skip_special_tokens=True)]
        input_ids = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
        ).input_ids.to(input_ids.device)

        # replace bos with pad to not condition healing on it
        input_ids = torch.where(input_ids == bos_id, pad_id, input_ids)

        tail_ids = input_ids[:, -1].tolist()
        space_tok = tokenizer.tokenize(" ")[0]

> **Review comment** (on `space_tok`): Not 100% sure this will always do what you want, specifically for tokenizers that add a prefix token you could get [
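To illustrate the reviewer's concern, a short sketch; the checkpoints are chosen only as examples of two tokenizer families and the exact pieces depend on the tokenizer and its settings.

```py
from transformers import AutoTokenizer

# The piece(s) returned for a bare space differ across tokenizer families
# (byte-level BPE vs. SentencePiece with a prefix space), which is what the
# reviewer is pointing at. Outputs depend on the checkpoint.
for name in ["gpt2", "hf-internal-testing/llama-tokenizer"]:
    tok = AutoTokenizer.from_pretrained(name)
    print(name, tok.tokenize(" "))
```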
        # tail tokens are used for a prefix search, thus, whitespaces are replaced with
        # their tokenization (e.g. 'Ġ') to enable search for tokens prefixed with a whitespace
        tail_toks = (tokenizer.decode(t).replace(" ", space_tok) for t in tail_ids)

        for batch_idx, (tail_id, tail_tok) in enumerate(zip(tail_ids, tail_toks)):
            batch_ids = input_ids[batch_idx]
            if torch.all(batch_ids == pad_id).item():
                continue  # skip empty sequences (all pad ids)

            # apply bias for alternatives (extensions) to the tail token
            seq_bias = {(alt_tok,): 10.0 for alt_tok in vocab_trie.values(prefix=tail_tok)}
            if len(seq_bias) == 1:
                continue  # skip if there are no token alternatives to heal with

            # slightly favor original token to limit aggressive healing e.g. 'http' -> 'https'
            seq_bias[(tail_id,)] += 1.0
            gen_cfg.update(sequence_bias=seq_bias)

            trimmed_ids = batch_ids[:-1]
            # if the prompt is a single (non-pad) token, regenerate from bos
            if len(batch_ids[batch_ids != pad_id]) == 1:
                trimmed_ids[-1] = bos_id

            input_ids[batch_idx] = self.generate(trimmed_ids.unsqueeze(0), generation_config=gen_cfg)

        return input_ids

    def contrastive_search(self, *args, **kwargs):
        logger.warning_once(
            "Calling `contrastive_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
> **Review comment**: if this is an optional dependency we need to protect the import
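A minimal sketch of one way to guard an optional dependency, for reference. The names below are illustrative and do not refer to transformers' actual import utilities.

```py
# Illustrative guard for an optional dependency (names are made up, not transformers helpers).
try:
    from pygtrie import CharTrie

    _PYGTRIE_AVAILABLE = True
except ImportError:
    CharTrie = None
    _PYGTRIE_AVAILABLE = False


def _require_pygtrie():
    """Raise a clear error when token healing is requested but pygtrie is not installed."""
    if not _PYGTRIE_AVAILABLE:
        raise ImportError(
            "Token healing requires the optional `pygtrie` dependency. Install it with `pip install pygtrie`."
        )
```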