
Commit 162056a

change sequence_bias type of SequenceBiasLogitsProcessor to list, add config tests for all processors (#33375)

* change sequence_bias type of SequenceBiasLogitsProcessor to list, add config tests for all processors

* fix format

* small fix for all_token_bias_pairs_are_valid internal func

* small typo fix in description

* improve test impl, some SequenceBiasLogitsProcessor refactoring
VladOS95-cyber authored Sep 19, 2024
1 parent d9d59e7 commit 162056a
Showing 2 changed files with 486 additions and 16 deletions.
53 changes: 40 additions & 13 deletions src/transformers/generation/logits_process.py
@@ -15,7 +15,7 @@

 import inspect
 import math
-from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Callable, Iterable, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -1064,8 +1064,9 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
     </Tip>
     Args:
-        sequence_bias (`Dict[Tuple[int], float]`):
-            Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
+        sequence_bias (`List[List[Union[List[int], float]]]`):
+            List of lists that maps a sequence of tokens to its bias term (e.g. `[[[10, 45], -2.0],
+            [[64], -7.5]]`). Positive biases increase the odds of the
             sequence being selected, while negative biases do the opposite. If a sequence has a length of 1, its bias
             will always be applied. Otherwise, the bias will only be applied if the sequence in question is about to be
             completed (in the token selection step after this processor is applied).
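
In other words, each entry now pairs a list of token ids with its bias, instead of keying a dict by a tuple of token ids. A minimal sketch of the two shapes side by side, reusing the token ids from the docstring example above:

```python
# Old (pre-#33375) format: dict keyed by tuples of token ids.
old_style_bias = {(10, 45): -2.0, (64,): -7.5}

# New format: a list of [token-id list, bias] pairs carrying the same information.
new_style_bias = [[[10, 45], -2.0], [[64], -7.5]]
```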
@@ -1087,12 +1088,12 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
     >>> tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("openai-community/gpt2", add_prefix_space=True)

-    >>> def get_tokens_as_tuple(word):
-    ...     return tuple(tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0])
+    >>> def get_tokens(word):
+    ...     return tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0]

     >>> # If we add a negative bias without beam search, it may become "stuck" in a prefix without good continuations
-    >>> sequence_bias = {get_tokens_as_tuple("Trump"): -10.0}
+    >>> sequence_bias = [get_tokens("Trump"), -10.0]
     >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, sequence_bias=sequence_bias)
     >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
     The full name of Donald is Donald J. Donald,
@@ -1102,16 +1103,17 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
     The full name of Donald is Donald Rumsfeld,

     >>> # We can also add a positive bias to nudge the model towards specific tokens or continuations
-    >>> sequence_bias = {get_tokens_as_tuple("Donald Duck"): 10.0}
+    >>> sequence_bias = [get_tokens("Donald Duck"), 10.0]
     >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
     >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
     The full name of Donald is Donald Duck.
     ```
     """

-    def __init__(self, sequence_bias: Dict[Tuple[int], float]):
+    def __init__(self, sequence_bias: List[List[Union[List[int], float]]]):
         self.sequence_bias = sequence_bias
         self._validate_arguments()
+        self._convert_list_arguments_into_dict()

         # Bias variables that will be populated on the first call (for retrocompatibility purposes, the vocabulary size
         # is inferred in the first usage, which inhibits initializing here)
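
One caveat worth flagging: the validator added below treats `sequence_bias` as a list of `[token ids, bias]` pairs, so a single biased sequence still needs the enclosing outer list, and a flat pair would be rejected. A hedged sketch of the distinction, with invented token ids:

```python
# Accepted: one [token-id list, float] pair inside the outer list.
sequence_bias = [[[3014, 5126], -10.0]]

# Rejected: for the element [3014, 5126], sequence[0] is the int 3014 rather
# than a list of token ids, so `all_token_bias_pairs_are_valid` returns False
# and `_validate_arguments` raises a ValueError.
sequence_bias = [[3014, 5126], -10.0]
```

Read this way, the docstring's `[get_tokens("Trump"), -10.0]` appears to need the same outer nesting to pass the validation introduced in this commit.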
@@ -1178,11 +1180,15 @@ def _prepare_bias_variables(self, scores: torch.FloatTensor):

     def _validate_arguments(self):
         sequence_bias = self.sequence_bias
-        if not isinstance(sequence_bias, dict) or len(sequence_bias) == 0:
-            raise ValueError(f"`sequence_bias` has to be a non-empty dictionary, but is {sequence_bias}.")
-        if any(not isinstance(sequence_ids, tuple) for sequence_ids in sequence_bias.keys()):
+        if not isinstance(sequence_bias, dict) and not isinstance(sequence_bias, list) or len(sequence_bias) == 0:
+            raise ValueError(
+                f"`sequence_bias` has to be a non-empty dictionary, or non-empty list of lists but is {sequence_bias}."
+            )
+        if isinstance(sequence_bias, dict) and any(
+            not isinstance(sequence_ids, tuple) for sequence_ids in sequence_bias.keys()
+        ):
             raise ValueError(f"`sequence_bias` has to be a dict with tuples as keys, but is {sequence_bias}.")
-        if any(
+        if isinstance(sequence_bias, dict) and any(
             any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in sequence_ids)
             or len(sequence_ids) == 0
             for sequence_ids in sequence_bias.keys()
@@ -1191,9 +1197,30 @@ def _validate_arguments(self):
                 f"Each key in `sequence_bias` has to be a non-empty tuple of positive integers, but is "
                 f"{sequence_bias}."
             )
-        if any(not isinstance(bias, float) for bias in sequence_bias.values()):
+
+        def all_token_bias_pairs_are_valid(sequence):
+            return (
+                isinstance(sequence[0], list)
+                and all(isinstance(token_id, (int, np.integer)) and token_id > 0 for token_id in sequence[0])
+                and isinstance(sequence[1], float)
+            )
+
+        if isinstance(sequence_bias, list) and any(
+            (not all_token_bias_pairs_are_valid(sequence)) or len(sequence) == 0 for sequence in sequence_bias
+        ):
+            raise ValueError(
+                f"Each element in `sequence_bias` has to be a non-empty list of lists of positive integers and float, but is "
+                f"{sequence_bias}."
+            )
+        if isinstance(sequence_bias, dict) and any(not isinstance(bias, float) for bias in sequence_bias.values()):
             raise ValueError(f"`sequence_bias` has to be a dict with floats as values, but is {sequence_bias}.")

+    def _convert_list_arguments_into_dict(self):
+        """BC: we used to accept `dict{tuple of tokens: float}` directly, now we expect a list"""
+        if isinstance(self.sequence_bias, list):
+            temp_sequence = self.sequence_bias
+            self.sequence_bias = {tuple(sublist[0]): sublist[1] for sublist in temp_sequence}


 class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
     """
(Diff for the second changed file, carrying the remaining 446 additions and 3 deletions, was not rendered on this page.)
