Generate: consistently handle special tokens as tensors #30624

Merged: 9 commits, May 9, 2024
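Across both files in this diff, the change is the same: `pad_token_id` / `eos_token_id` may arrive as an `int`, a `List[int]`, or already a `torch.Tensor`, and each call site now normalizes it to a tensor once instead of rebuilding Python lists (or fresh tensors) on every decoding step. A minimal sketch of that normalization, distilled from the hunks below — the helper name `_normalize_special_token` is illustrative only; the PR inlines the check at each site rather than adding a helper:

```python
from typing import List, Optional, Union

import torch


def _normalize_special_token(
    token_id: Optional[Union[int, List[int], torch.Tensor]],
) -> Optional[torch.Tensor]:
    # None stays None; an int becomes a one-element list; lists become tensors;
    # tensors pass through unchanged.
    if token_id is not None and not isinstance(token_id, torch.Tensor):
        if isinstance(token_id, int):
            token_id = [token_id]
        token_id = torch.tensor(token_id)
    return token_id


print(_normalize_special_token(2))           # tensor([2])
print(_normalize_special_token([2, 32000]))  # tensor([    2, 32000])
print(_normalize_special_token(None))        # None
```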
40 changes: 24 additions & 16 deletions src/transformers/generation/beam_search.py
@@ -218,8 +218,8 @@ def process(
next_scores: torch.FloatTensor,
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
beam_indices: Optional[torch.LongTensor] = None,
group_index: Optional[int] = 0,
decoder_prompt_len: Optional[int] = 0,
@@ -245,8 +245,10 @@ def process(
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)

for batch_idx in range(batch_size):
batch_group_idx = batch_idx * self.num_beam_groups + group_index
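With `eos_token_id` held as a tensor, the end-of-sequence checks further down in `process()` (not shown in this hunk) can use tensor membership semantics directly. A standalone illustration of the two styles of check, assuming toy token ids — this is not the PR's exact loop body:

```python
import torch

eos_token_id = torch.tensor([2, 32000])          # normalized as in the hunk above
next_tokens = torch.tensor([[15, 2, 7, 32000]])  # candidate tokens for one batch entry

# Scalar check: Python's `in` also works against a 1-D tensor.
print(int(next_tokens[0, 1]) in eos_token_id)    # True

# Vectorized check over all candidates at once.
print(torch.isin(next_tokens, eos_token_id))     # tensor([[False,  True, False,  True]])
```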
@@ -322,15 +324,17 @@ def finalize(
final_beam_tokens: torch.LongTensor,
final_beam_indices: torch.LongTensor,
max_length: int,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
) -> Tuple[torch.LongTensor]:
batch_size = len(self._beam_hyps) // self.num_beam_groups

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)

# finalize all open beam hypotheses and add to generated hypotheses
for batch_group_idx, beam_hyp in enumerate(self._beam_hyps):
@@ -513,8 +517,8 @@ def process(
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
scores_for_all_vocab: torch.FloatTensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
) -> Tuple[torch.Tensor]:
@@ -578,8 +582,10 @@ def process(
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)

for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
@@ -811,15 +817,17 @@ def finalize(
final_beam_tokens: torch.LongTensor,
final_beam_indices: torch.LongTensor,
max_length: int,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
) -> Tuple[torch.LongTensor]:
batch_size = len(self._beam_hyps)

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)

# finalize all open beam hypotheses and add to generated hypotheses
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
137 changes: 81 additions & 56 deletions src/transformers/generation/logits_process.py
@@ -108,8 +108,8 @@ class MinLengthLogitsProcessor(LogitsProcessor):
Args:
min_length (`int`):
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.

Examples:

@@ -137,23 +137,23 @@ class MinLengthLogitsProcessor(LogitsProcessor):
```
"""

def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Tensor]):
if not isinstance(min_length, int) or min_length < 0:
raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}")

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id):
logger.warning(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)

self.min_length = min_length
self.eos_token_id = eos_token_id

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
eos_token_id = torch.tensor(self.eos_token_id, device=scores.device)
eos_token_mask = torch.isin(vocab_tensor, eos_token_id)
self.eos_token_id = self.eos_token_id.to(scores.device)
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
scores_processed = scores.clone()
if input_ids.shape[-1] < self.min_length:
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
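`__call__` now reuses the tensor cached at `__init__`, only moving it to the scores' device, and builds the EOS mask with `torch.isin`. A self-contained sketch of that masking step with a made-up vocabulary size and all-zero logits:

```python
import math

import torch

vocab_size = 8
scores = torch.zeros(1, vocab_size)            # stand-in logits for one sequence
eos_token_id = torch.tensor([2, 5])            # cached at __init__ ...
eos_token_id = eos_token_id.to(scores.device)  # ... moved to the logits' device here

vocab_tensor = torch.arange(vocab_size, device=scores.device)
eos_token_mask = torch.isin(vocab_tensor, eos_token_id)

# Below the minimum length, EOS logits are forced to -inf so EOS cannot be selected.
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
print(scores_processed)  # tensor([[0., 0., -inf, 0., 0., -inf, 0., 0.]])
```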
@@ -171,8 +171,8 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
input length.
min_new_tokens (`int`):
The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.

Examples:

@@ -195,18 +195,20 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
```
"""

def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int]]):
def __init__(
self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int], torch.Tensor]
):
for arg_name, arg_value in [
("prompt_length_to_skip", prompt_length_to_skip),
("min_new_tokens", min_new_tokens),
]:
if not isinstance(arg_value, int) or arg_value < 0:
raise ValueError(f"`{arg_name}` has to be a positive integer, but is {arg_value}")

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id):
logger.warning(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)

self.prompt_length_to_skip = prompt_length_to_skip
self.min_new_tokens = min_new_tokens
@@ -217,8 +219,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
scores_processed = scores.clone()
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
eos_token_id = torch.tensor(self.eos_token_id, device=scores.device)
eos_token_mask = torch.isin(vocab_tensor, eos_token_id)
self.eos_token_id = self.eos_token_id.to(scores.device)
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
if new_tokens_length < self.min_new_tokens:
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
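Since the constructors now also accept tensors, a caller can pass `eos_token_id` as a tensor directly; an `int` or `List[int]` still works as before. A hedged usage sketch with dummy inputs, assuming a `transformers` build that includes this PR:

```python
import torch
from transformers import MinNewTokensLengthLogitsProcessor

processor = MinNewTokensLengthLogitsProcessor(
    prompt_length_to_skip=4,
    min_new_tokens=3,
    eos_token_id=torch.tensor([2]),  # int, List[int], or torch.Tensor are all accepted
)

input_ids = torch.tensor([[5, 7, 9, 11, 13]])  # 4 prompt tokens + 1 generated token
scores = torch.zeros(1, 16)                    # dummy logits over a 16-token vocabulary

out = processor(input_ids, scores)
print(out[0, 2])  # -inf: EOS stays blocked until 3 new tokens have been generated
```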

@@ -1118,8 +1120,8 @@ class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
Args:
bad_words_ids (`List[List[int]]`):
List of list of token ids that are not allowed to be generated.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
eos_token_id (`Union[int, List[int], torch.Tensor]`, *optional*):
The id(s) of the *end-of-sequence* token.

Examples:

@@ -1156,18 +1158,22 @@ class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
```
"""

def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):
def __init__(
self, bad_words_ids: List[List[int]], eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None
):
self.bad_word_ids = bad_words_ids
self._validate_arguments()

# Filter EOS token from bad_words_ids
if eos_token_id is None:
eos_token_id = []
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
bad_words_ids = list(
filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids)
)
if eos_token_id is not None:
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)

bad_words_ids = list(
filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids)
)

# Forbidding a sequence is equivalent to setting its bias to -inf
sequence_bias = {tuple(sequence): float("-inf") for sequence in bad_words_ids}
@@ -1445,9 +1451,8 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
Args:
max_length (`int`):
The maximum length of the sequence to be generated.
eos_token_id (`Union[int, List[int]]`):
The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a
list to set multiple *end-of-sequence* tokens.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.

Examples:

@@ -1471,15 +1476,22 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
```
"""

def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Tensor]):
self.max_length = max_length
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]

if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
self.eos_token_id = eos_token_id

if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
self.eos_token_id = self.eos_token_id.to(scores.device)
scores_processed = scores
if cur_len == self.max_length - 1:
scores_processed = torch.full_like(scores, -math.inf)
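Validation now runs on the normalized tensor: instead of `isinstance(i, int)` checks on a Python list, the constructor rejects floating-point dtypes and negative ids. A small illustration of what the new check catches, using hypothetical values:

```python
import torch


def _validate_eos(eos_token_id: torch.Tensor) -> None:
    # Mirrors the check added in this hunk.
    if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
        raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")


for candidate in (torch.tensor([2, 32000]), torch.tensor([2.0]), torch.tensor([-1])):
    try:
        _validate_eos(candidate)
        print(candidate, "-> ok")
    except ValueError as err:
        print(candidate, "->", err)  # rejects the float dtype and the negative id
```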
@@ -1518,8 +1530,8 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
exponential_decay_length_penalty (`tuple(int, float)`):
This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty
starts and `decay_factor` represents the factor of exponential decay
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.
input_ids_seq_length (`int`):
The length of the input sequence.

@@ -1579,27 +1591,33 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
def __init__(
self,
exponential_decay_length_penalty: Tuple[int, float],
eos_token_id: Union[int, List[int]],
eos_token_id: Union[int, List[int], torch.Tensor],
input_ids_seq_length: int,
):
self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
self.regulation_factor = exponential_decay_length_penalty[1]
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]

if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
self.eos_token_id = eos_token_id

if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
self.eos_token_id = self.eos_token_id.to(scores.device)
penalties = torch.zeros_like(scores)
scores_processed = scores
if cur_len > self.regulation_start:
for i in self.eos_token_id:
penalty_idx = cur_len - self.regulation_start
# To support negative logits we compute the penalty of the absolute value and add to the original logit
penalty = torch.abs(scores[:, i]) * (pow(self.regulation_factor, penalty_idx) - 1)
penalties[:, i] = penalty
scores_processed = scores + penalties
penalty_idx = cur_len - self.regulation_start
# To support negative logits we compute the penalty of the absolute value and add to the original logit
penalty = torch.abs(scores[:, self.eos_token_id]) * (pow(self.regulation_factor, penalty_idx) - 1)
penalties[:, self.eos_token_id] = penalty
scores_processed = scores + penalties
return scores_processed
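The per-token Python loop over `eos_token_id` is replaced by one vectorized indexing operation, `scores[:, self.eos_token_id]`, which penalizes every EOS id at once. A standalone check of the equivalence with toy numbers:

```python
import torch

scores = torch.tensor([[0.5, -1.0, 2.0, 0.1]])
eos_token_id = torch.tensor([1, 3])
regulation_factor, penalty_idx = 1.5, 2

# Old behaviour: loop over each EOS id separately.
penalties_loop = torch.zeros_like(scores)
for i in eos_token_id:
    penalties_loop[:, i] = torch.abs(scores[:, i]) * (pow(regulation_factor, penalty_idx) - 1)

# New behaviour: a single advanced-indexing assignment on the tensor of EOS ids.
penalties_vec = torch.zeros_like(scores)
penalties_vec[:, eos_token_id] = torch.abs(scores[:, eos_token_id]) * (
    pow(regulation_factor, penalty_idx) - 1
)

print(torch.allclose(penalties_loop, penalties_vec))  # True
```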


@@ -1676,7 +1694,7 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
"""

def __init__(self, begin_suppress_tokens, begin_index):
self.begin_suppress_tokens = list(begin_suppress_tokens)
self.begin_suppress_tokens = torch.tensor(list(begin_suppress_tokens))
self.begin_index = begin_index

def set_begin_index(self, begin_index):
@@ -1685,8 +1703,8 @@ def set_begin_index(self, begin_index):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
begin_suppress_tokens = torch.tensor(self.begin_suppress_tokens, device=scores.device)
suppress_token_mask = torch.isin(vocab_tensor, begin_suppress_tokens)
self.begin_suppress_tokens = self.begin_suppress_tokens.to(scores.device)
suppress_token_mask = torch.isin(vocab_tensor, self.begin_suppress_tokens)
scores_processed = scores
if input_ids.shape[-1] == self.begin_index:
scores_processed = torch.where(suppress_token_mask, -float("inf"), scores)
@@ -1724,13 +1742,13 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
"""

def __init__(self, suppress_tokens):
self.suppress_tokens = list(suppress_tokens)
self.suppress_tokens = torch.tensor(list(suppress_tokens))

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
suppress_tokens = torch.tensor(self.suppress_tokens, device=scores.device)
suppress_token_mask = torch.isin(vocab_tensor, suppress_tokens)
self.suppress_tokens = self.suppress_tokens.to(scores.device)
suppress_token_mask = torch.isin(vocab_tensor, self.suppress_tokens)
scores = torch.where(suppress_token_mask, -float("inf"), scores)
return scores
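The design choice shared by these suppress-token processors: build the token-id tensor once in `__init__` (on CPU) and move it to the logits' device lazily inside `__call__`, rather than constructing a fresh tensor from a Python list on every decoding step. A sketch of that pattern in isolation — the class name is illustrative, not part of the PR:

```python
import torch


class _SuppressSketch:
    def __init__(self, suppress_tokens):
        # Built once on CPU; no per-step tensor construction.
        self.suppress_tokens = torch.tensor(list(suppress_tokens))

    def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
        # `.to()` is a no-op once the cached tensor already lives on scores.device.
        self.suppress_tokens = self.suppress_tokens.to(scores.device)
        vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
        mask = torch.isin(vocab_tensor, self.suppress_tokens)
        return torch.where(mask, -float("inf"), scores)


print(_SuppressSketch([0, 4])(torch.zeros(1, 6)))
# tensor([[-inf, 0., 0., 0., -inf, 0.]])
```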

@@ -2191,23 +2209,30 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
</Tip>

Args:
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.
min_eos_p (`float`, *optional*):
Minimum end of speech threshold.
"""

def __init__(self, eos_token_id: Union[int, List[int]], min_eos_p: float):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: float):
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
self.eos_token_id = eos_token_id

if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")

if min_eos_p is not None and min_eos_p <= 0:
raise ValueError(f"`min_eos_p` has to be a positive float, but is {min_eos_p}")
self.min_eos_p = min_eos_p

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
scores_processed = scores
self.eos_token_id = self.eos_token_id.to(scores.device)
if self.min_eos_p:
probs = torch.nn.functional.softmax(scores.float(), dim=-1)
# create scores full of -inf except for the eos_token_id
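The rest of this method is cut off in the diff view. Going only by the docstring and the comment above, the intent appears to be: once the EOS probability crosses `min_eos_p`, every non-EOS logit is masked so generation can stop early. A hedged sketch of that idea with toy logits — not necessarily the exact implementation:

```python
import torch

scores = torch.tensor([[1.0, 4.0, 0.5, 0.2]])  # toy logits; the EOS id is 1
eos_token_id = torch.tensor([1])
min_eos_p = 0.5

probs = torch.nn.functional.softmax(scores.float(), dim=-1)
if (probs[:, eos_token_id] > min_eos_p).any():
    # Scores full of -inf except for the EOS id(s), so only EOS can be picked.
    early_stop_scores = torch.full_like(scores, -float("inf"))
    early_stop_scores[:, eos_token_id] = scores[:, eos_token_id]
    scores = early_stop_scores

print(scores)  # tensor([[-inf, 4., -inf, -inf]])
```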