token healing impl #29081

Closed · wants to merge 9 commits
Changes from 6 commits
3 changes: 2 additions & 1 deletion examples/flax/_tests_requirements.txt
@@ -7,4 +7,5 @@ seqeval
tensorboard
evaluate >= 0.2.0
torch
accelerate
accelerate
pygtrie >= 2.5.0
44 changes: 44 additions & 0 deletions examples/research_projects/token-healing/README.md
@@ -0,0 +1,44 @@
<!-- back to top link -->
<a name="readme-top"></a>

<!-- ABOUT THE PROJECT -->
## What is token healing?

Token healing rectifies the token boundary bias in greedy tokenization. It does this by trimming and regrowing the prompt to better align with the model's tokenizer, thus enhancing generation quality. The improvement is clearest with completion models.

Example: given a completion prompt with a partial URL ending in `:`, the model may have seen the expected completion `://` as a _single_ token in training. However, the prompt's tail token `:` tells it that the next token is not `//`, so it searches for other (wrong) completions. Such errors compound in auto-regressive language models.

Debiasing token boundaries also addresses output sensitivity to prompts ending with whitespace.

A more thorough explanation can be found on [The Art of Prompt Design: Prompt Boundaries and Token Healing | by Scott Lundberg](https://towardsdatascience.com/the-art-of-prompt-design-prompt-boundaries-and-token-healing-3b2448b0be38).
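
The bias is easy to see by tokenizing the two prompt variants directly. A minimal sketch, assuming a BPE-style tokenizer (GPT-2 is used here purely for illustration and is not part of this project):

```py
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # illustrative choice; any BPE tokenizer shows the effect

# With the full prefix, '://' can be absorbed into a single multi-character token ...
print(tok.tokenize('The link is <a href="http://'))
# ... but a prompt that stops at ':' forces the colon into its own token, a boundary the model
# rarely saw during training, which is exactly what token healing repairs.
print(tok.tokenize('The link is <a href="http:'))
```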

## Installation

`pip install transformers pygtrie`.

## Usage

```py
prompt = 'The link is <a href="http:'
raw_output = generate(prompt, completion_model, tokenizer, token_healing=False)
# The link is <a href="http:&#47;&#47;www&#47;dailymail&#

# The model saw '://' as a single token in training. Seeing a prompt ending with `:` tells it that the
# next token is likely not `//`, because otherwise it would've seen `://`.
# Thus, it completes with a token other than `//`, in this case, `&`.

healed_output = generate(prompt, completion_model, tokenizer, token_healing=True)
# The link is <a href="http://www.365doki.com/post/3699

# You can also use token healing in isolation
# This can be useful if you have other work to do before the generation
# Or if you want to delegate generation to another process
input_ids = tokenizer(test_prompts, return_tensors='pt', padding=True).input_ids.cuda()
healed_ids = completion_model.heal_tokens(input_ids, tokenizer)
healed_prompts = tokenizer.batch_decode(healed_ids, skip_special_tokens=True)
# outputs the healed prompts without further completion/generation
```

See `run_token_healing.py` for the full example.

<p align="right">(<a href="#readme-top">back to top</a>)</p>
2 changes: 2 additions & 0 deletions examples/research_projects/token-healing/requirements.txt
@@ -0,0 +1,2 @@
pygtrie >= 2.5.0
transformers >= 4.36.2
48 changes: 48 additions & 0 deletions examples/research_projects/token-healing/run_token_healing.py
@@ -0,0 +1,48 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


def generate(inputs, model, tokenizer, token_healing):
    input_ids = tokenizer(inputs, return_tensors="pt", padding=True).input_ids.cuda()
    generation_config = GenerationConfig(
        max_new_tokens=8,
        token_healing=token_healing,
        pad_token_id=model.config.pad_token_id,
        repetition_penalty=1.1,
    )
    # pass the tokenizer so `generate` can forward it to `heal_tokens` when token healing is enabled
    output = model.generate(inputs=input_ids, generation_config=generation_config, tokenizer=tokenizer)
    return tokenizer.batch_decode(output, skip_special_tokens=True)


model_name_or_path = "TheBloke/deepseek-llm-7B-base-GPTQ"
completion_model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    device_map="auto",
    trust_remote_code=False,
    revision="main",
    use_cache=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

test_prompts = [
    'An example ["like this"] and another example [',
    'The link is <a href="http:',
    'The link is <a href="http',  # test aggressive healing http->https
    "I read a book about ",  # test trailing whitespace
    "I read a book about",  # test nothing to heal
]

raw_output = generate(test_prompts, completion_model, tokenizer, token_healing=False)
healed_output = generate(test_prompts, completion_model, tokenizer, token_healing=True)

for p, a, b in zip(test_prompts, raw_output, healed_output):
    print(f"\nPrompt: {p}\nWithout healing:\n{a}\nWith healing:\n{b}")

# You can also use token healing in isolation
# This can be useful if you have other work to do before the generation
# Or if you want to delegate generation to another process
input_ids = tokenizer(test_prompts, return_tensors="pt", padding=True).input_ids.cuda()
healed_ids = completion_model.heal_tokens(input_ids, tokenizer)
healed_prompts = tokenizer.batch_decode(healed_ids, skip_special_tokens=True)
print("\nhealed prompts:")
for p in healed_prompts:
    print(p)
4 changes: 4 additions & 0 deletions setup.py
@@ -145,6 +145,7 @@
    "psutil",
    "pyyaml>=5.1",
    "pydantic",
    "pygtrie>=2.5.0",
    "pytest>=7.2.0,<8.0.0",
    "pytest-timeout",
    "pytest-xdist",
@@ -269,6 +270,7 @@ def run(self):
extras["retrieval"] = deps_list("faiss-cpu", "datasets")
extras["flax"] = deps_list("jax", "jaxlib", "flax", "optax")

extras["generate"] = deps_list("pygtrie")
extras["tokenizers"] = deps_list("tokenizers")
extras["ftfy"] = deps_list("ftfy")
extras["onnxruntime"] = deps_list("onnxruntime", "onnxruntime-tools")
@@ -324,6 +326,7 @@ def run(self):
    )
    + extras["retrieval"]
    + extras["modelcreation"]
    + extras["generate"]
)

extras["deepspeed-testing"] = extras["deepspeed"] + extras["testing"] + extras["optuna"] + extras["sentencepiece"]
@@ -344,6 +347,7 @@ def run(self):
    + extras["codecarbon"]
    + extras["accelerate"]
    + extras["video"]
    + extras["generate"]
)

# Might need to add doc-builder and some specific deps in the future
1 change: 1 addition & 0 deletions src/transformers/dependency_versions_table.py
@@ -51,6 +51,7 @@
    "psutil": "psutil",
    "pyyaml": "pyyaml>=5.1",
    "pydantic": "pydantic",
    "pygtrie": "pygtrie>=2.5.0",
    "pytest": "pytest>=7.2.0,<8.0.0",
    "pytest-timeout": "pytest-timeout",
    "pytest-xdist": "pytest-xdist",
4 changes: 4 additions & 0 deletions src/transformers/generation/configuration_utils.py
@@ -195,6 +195,9 @@ class GenerationConfig(PushToHubMixin):
            Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
            sequence being selected, while negative biases do the opposite. Check
            [`~generation.SequenceBiasLogitsProcessor`] for further documentation and examples.
        token_healing (`bool`, *optional*, defaults to `False`):
            Heal tail tokens of prompts by replacing them with their appropriate extensions.
            This enhances the quality of completions for prompts affected by greedy tokenization bias.
        guidance_scale (`float`, *optional*):
            The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
            Higher guidance scale encourages the model to generate samples that are more closely linked to the input
@@ -309,6 +312,7 @@ def __init__(self, **kwargs):
        self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
        self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
        self.sequence_bias = kwargs.pop("sequence_bias", None)
        self.token_healing = kwargs.pop("token_healing", False)
        self.guidance_scale = kwargs.pop("guidance_scale", None)
        self.low_memory = kwargs.pop("low_memory", None)

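Taken together with the `generate`-side changes below (which pop a `tokenizer` kwarg and call `heal_tokens` when the flag is set), the new option is meant to be used roughly as follows. This is a sketch based on this PR's state; the GPT-2 checkpoint is an illustrative stand-in, not part of the PR:

```py
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model = AutoModelForCausalLM.from_pretrained("gpt2")  # any causal LM works for illustration
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # healing re-tokenizes with padding, so a pad token is needed

gen_cfg = GenerationConfig(max_new_tokens=8, token_healing=True, pad_token_id=tokenizer.pad_token_id)
input_ids = tokenizer('The link is <a href="http:', return_tensors="pt").input_ids

# The tokenizer must be forwarded to `generate` so the prompt can be trimmed and re-grown.
output = model.generate(inputs=input_ids, generation_config=gen_cfg, tokenizer=tokenizer)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
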
67 changes: 67 additions & 0 deletions src/transformers/generation/utils.py
@@ -22,6 +22,7 @@

import torch
import torch.distributed as dist
from pygtrie import CharTrie
Collaborator: If this is an optional dependency we need to protect the import.

from torch import nn

from ..cache_utils import Cache, DynamicCache, StaticCache
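A guarded import along these lines is one way to keep `pygtrie` optional, per the reviewer's note above (a sketch only; the `requires_backends` hook suggested further down is the codebase's usual mechanism, and the helper below is illustrative rather than an existing API):

```py
# Sketch: protect the optional dependency at import time and fail only when the feature is used.
try:
    from pygtrie import CharTrie

    _pygtrie_available = True
except ImportError:
    CharTrie = None
    _pygtrie_available = False


def _require_pygtrie():
    # Would be called at the top of `heal_tokens`.
    if not _pygtrie_available:
        raise ImportError("Token healing requires the `pygtrie` package: `pip install pygtrie`")
```
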
@@ -85,6 +86,7 @@

if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel
    from ..tokenization_utils_base import PreTrainedTokenizerBase
    from .streamers import BaseStreamer

logger = logging.get_logger(__name__)
@@ -1323,6 +1325,7 @@ def generate(
                synced_gpus = True
            else:
                synced_gpus = False
        tokenizer = kwargs.pop("tokenizer", None)

        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
        self._validate_model_class()
@@ -1432,6 +1435,9 @@
        else:
            input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

        if generation_config.token_healing:
            input_ids = self.heal_tokens(input_ids, tokenizer)

        if streamer is not None:
            streamer.put(input_ids.cpu())

@@ -1803,6 +1809,67 @@ def typeerror():

        return result

    def heal_tokens(
        self, input_ids: torch.LongTensor, tokenizer: Optional["PreTrainedTokenizerBase"] = None
    ) -> torch.LongTensor:
        r"""
        Heals the tail token of each prompt by replacing it with an appropriate extension, mitigating the
        boundary bias introduced by greedy tokenization.
        Parameters:
            input_ids (`torch.LongTensor`): The sequence used as a prompt for the generation.
            tokenizer (`PreTrainedTokenizerBase`, *optional*): The tokenizer used to decode the input ids.
        Return:
            `torch.LongTensor` where each sequence has its tail token replaced with its appropriate extension.
        """
        if tokenizer is None:
Collaborator commented on lines +1826 to +1827 with a suggested change, replacing

        """
        if tokenizer is None:

with

        """
        requires_backends(self, ["pygtrie"])

we also need to make sure this function errors out correctly if used

            raise ValueError(
                " When generating with token healing, you must pass the model's tokenizer to the `tokenizer` "
                "argument of `generate`."
            )
        bos_id, pad_id = tokenizer.bos_token_id, tokenizer.pad_token_id
        vocab_trie = CharTrie(tokenizer.get_vocab())
Collaborator: BTW we have https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils.py#L52 which could be used for this? Would remove the dependency? (It might be additional work as well.)

        gen_cfg = GenerationConfig(max_new_tokens=1, pad_token_id=pad_id)

        # assumption: leading/trailing whitespace is not meaningful, so the prompts are
        # stripped before re-tokenizing to desensitize generation to whitespace artefacts
        prompts = [p.strip() for p in tokenizer.batch_decode(input_ids, skip_special_tokens=True)]
        input_ids = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
        ).input_ids.to(input_ids.device)

        # replace bos with pad to not condition healing on it
        input_ids = torch.where(input_ids == bos_id, pad_id, input_ids)

        tail_ids = input_ids[:, -1].tolist()
        space_tok = tokenizer.tokenize(" ")[0]
Collaborator: Not 100% sure this will always do what you want; specifically, for tokenizers that add a prefix token you could get [▁▁]

        # tail tokens are used for a prefix search, thus, whitespaces are replaced with
        # their tokenization (e.g. 'Ġ') to enable search for tokens prefixed with a whitespace
        tail_toks = (tokenizer.decode(t).replace(" ", space_tok) for t in tail_ids)

        for batch_idx, (tail_id, tail_tok) in enumerate(zip(tail_ids, tail_toks)):
            batch_ids = input_ids[batch_idx]
            if torch.all(batch_ids == pad_id).item():
                continue  # skip empty sequences (all pad ids)

            # apply bias for alternatives (extensions) to the tail token
            seq_bias = {(alt_tok,): 10.0 for alt_tok in vocab_trie.values(prefix=tail_tok)}
            if len(seq_bias) == 1:
                continue  # skip if there are no token alternatives to heal with

            # slightly favor original token to limit aggressive healing e.g. 'http' -> 'https'
            seq_bias[(tail_id,)] += 1.0
            gen_cfg.update(sequence_bias=seq_bias)

            trimmed_ids = batch_ids[:-1]
            # if the prompt is a single (non-pad) token, regenerate from bos
            if len(batch_ids[batch_ids != pad_id]) == 1:
                trimmed_ids[-1] = bos_id

            input_ids[batch_idx] = self.generate(trimmed_ids.unsqueeze(0), generation_config=gen_cfg)

        return input_ids

    def contrastive_search(self, *args, **kwargs):
        logger.warning_once(
            "Calling `contrastive_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
37 changes: 37 additions & 0 deletions tests/generation/test_utils.py
@@ -26,6 +26,7 @@
from transformers.testing_utils import (
    is_flaky,
    require_accelerate,
    require_auto_gptq,
    require_torch,
    require_torch_multi_accelerator,
    slow,
@@ -3638,3 +3639,39 @@ def test_return_unprocessed_logit_scores(self):

        self.assertTrue(y_prob > 0.001 and n_prob > 0.001)
        self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0)


@require_torch
class TokenHealingTestCase(unittest.TestCase):
    @parameterized.expand(
        [
            (
                "square_bracket",
                'An example ["like this"] and another example [',
                'An example ["like this"] and another example ["',
            ),
            ("url", 'The link is <a href="http:', 'The link is <a href="http://'),
            ("aggressive_healing", 'The link is <a href="http', 'The link is <a href="http'),
            ("trailing_whitespace", "I read a book about ", "I read a book about"),
            ("nothing_to_heal", "I read a book about", "I read a book about"),
            ("single_token", "I", "I"),
            ("empty_prompt", "", ""),
        ]
    )
    @require_auto_gptq
    def test_prompts(self, name, input, expected):
        model_name_or_path = "TheBloke/deepseek-llm-7B-base-GPTQ"
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
        completion_model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            device_map="auto",
            trust_remote_code=False,
            revision="main",
            use_cache=True,
        )
        input_ids = tokenizer(input, return_tensors="pt").input_ids.to(completion_model.device)

        healed_ids = completion_model.heal_tokens(input_ids, tokenizer)
        predicted = tokenizer.decode(healed_ids[0], skip_special_tokens=True)

        self.assertEqual(predicted, expected)