Generation: fix handling of special tokens (#31254)
* fix special tokens in generation

* fix test

* add warning

* fix the check

* warn once

* fix
zucchini-nlp authored Jun 6, 2024
1 parent 7729b77 commit 5fabd1e
Showing 2 changed files with 29 additions and 30 deletions.
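At a glance, the fix changes how special tokens are resolved in `generate()`: each token is now looked up in the call's kwargs first and only then in the model's `generation_config`, and the encoder-decoder fallback from `decoder_start_token_id` to `bos_token_id` happens after that resolution instead of inside the removed `_get_decoder_start_token_id` helper. A minimal sketch of the precedence rule, with illustrative names (`resolve_token` is not part of the library):

```python
# Illustrative sketch of the precedence introduced by this commit; the real
# logic lives in `_prepare_special_tokens` (see the diff below).
def resolve_token(kwarg_value, config_value):
    # A value passed to `generate(...)` wins; otherwise fall back to the
    # model's `generation_config` default (which may itself be None).
    return kwarg_value if kwarg_value is not None else config_value

decoder_start = resolve_token(None, None)  # unset in both places
bos = resolve_token(None, 1)               # set only in the model config
if decoder_start is None:                  # encoder-decoder BC fallback (#30892)
    decoder_start = bos
assert decoder_start == 1
```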
55 changes: 27 additions & 28 deletions src/transformers/generation/utils.py
@@ -1436,23 +1436,6 @@ def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_l
             self._cache.reset()
         return self._cache
 
-    def _get_decoder_start_token_id(
-        self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None
-    ) -> int:
-        decoder_start_token_id = (
-            decoder_start_token_id
-            if decoder_start_token_id is not None
-            else self.generation_config.decoder_start_token_id
-        )
-        bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
-
-        if decoder_start_token_id is not None:
-            return decoder_start_token_id
-        elif bos_token_id is not None:
-            return bos_token_id
-        else:
-            return
-
     def _supports_default_dynamic_cache(self) -> bool:
         """
         Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`.
@@ -1478,25 +1461,32 @@ def _prepare_special_tokens(
         function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
         """
 
-        # Convert special tokens to tensors (if they exist)
-        def _tensor_or_none(token, device=None):
+        # Convert special tokens to tensors (if they exist either in kwargs or in self.config)
+        def _tensor_or_none(token_kwargs, token_self, device=None):
             if device is None:
                 device = self.device
 
+            token = token_kwargs if token_kwargs is not None else token_self
             if token is None or isinstance(token, torch.Tensor):
                 return token
             return torch.tensor(token, device=device, dtype=torch.long)
 
-        # for BC we also try to get `decoder_start_token_id` from model's generation config (#30892)
-        if self.config.is_encoder_decoder:
-            generation_config.decoder_start_token_id = self._get_decoder_start_token_id(
-                generation_config.decoder_start_token_id, generation_config.bos_token_id
-            )
+        bos_token_id = _tensor_or_none(
+            generation_config.bos_token_id, self.generation_config.bos_token_id, device=device
+        )
+        eos_token_id = _tensor_or_none(
+            generation_config.eos_token_id, self.generation_config.eos_token_id, device=device
+        )
+        pad_token_id = _tensor_or_none(
+            generation_config.pad_token_id, self.generation_config.pad_token_id, device=device
+        )
+        decoder_start_token_id = _tensor_or_none(
+            generation_config.decoder_start_token_id, self.generation_config.decoder_start_token_id, device=device
+        )
 
-        bos_token_id = _tensor_or_none(generation_config.bos_token_id, device=device)
-        eos_token_id = _tensor_or_none(generation_config.eos_token_id, device=device)
-        pad_token_id = _tensor_or_none(generation_config.pad_token_id, device=device)
-        decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id, device=device)
+        # for BC we also try to get `decoder_start_token_id` or `bos_token_id` (#30892)
+        if self.config.is_encoder_decoder:
+            decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
 
         # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
         if eos_token_id is not None and eos_token_id.ndim == 0:
@@ -1512,6 +1502,15 @@ def _tensor_or_none(token, device=None):
             pad_token_id = eos_token_id[0]
             logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.")
 
+        # we can't infer attn mask if pad token is set to be eos token in model's generation config
+        if eos_token_id is not None and torch.isin(elements=eos_token_id, test_elements=pad_token_id).any():
+            if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
+                logger.warning_once(
+                    "The attention mask is not set and cannot be inferred from input because pad token is same as eos token."
+                    "As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` "
+                    "to obtain reliable results."
+                )
+
         # Sanity checks/warnings
         if self.config.is_encoder_decoder and decoder_start_token_id is None:
             raise ValueError(
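To see the new nested helper and the pad/eos collision check in isolation, here is a self-contained sketch; `tensor_or_none` is a free-function stand-in for the nested `_tensor_or_none` above, and the token ids are made up for the demo:

```python
import torch

def tensor_or_none(token_kwargs, token_self, device="cpu"):
    # Prefer the caller-supplied value; fall back to the model's default.
    token = token_kwargs if token_kwargs is not None else token_self
    if token is None or isinstance(token, torch.Tensor):
        return token
    return torch.tensor(token, device=device, dtype=torch.long)

assert tensor_or_none(2, 50256).item() == 2          # kwarg wins
assert tensor_or_none(None, 50256).item() == 50256   # config default used
assert tensor_or_none(None, None) is None            # nothing set anywhere

# The new warning condition: if the pad token collides with any eos token,
# an attention mask cannot be inferred from the input ids alone.
eos_token_id = torch.tensor([50256, 50257])  # eos may hold several ids
pad_token_id = torch.tensor(50256)
if torch.isin(elements=eos_token_id, test_elements=pad_token_id).any():
    print("pad == eos: pass `attention_mask` explicitly for reliable results")
```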
4 changes: 2 additions & 2 deletions tests/generation/test_framework_agnostic.py
@@ -161,6 +161,7 @@ def test_transition_scores_greedy_search(self):
         tokenizer.pad_token = tokenizer.eos_token
 
         model = model_cls.from_pretrained("distilbert/distilgpt2")
+        model.generation_config.eos_token_id = None
         input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids
         if is_pt:
             model = model.to(torch_device)
@@ -170,7 +171,6 @@ def test_transition_scores_greedy_search(self):
             input_ids=input_ids,
             max_new_tokens=5,
             pad_token_id=tokenizer.eos_token_id,
-            eos_token_id=None,
             return_dict_in_generate=True,
             output_scores=True,
         )
@@ -197,6 +197,7 @@ def test_transition_scores_greedy_search_normalized(self):
         tokenizer.pad_token = tokenizer.eos_token
 
         model = model_cls.from_pretrained("distilbert/distilgpt2")
+        model.generation_config.eos_token_id = None
         input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids
         if is_pt:
             model = model.to(torch_device)
@@ -206,7 +207,6 @@ def test_transition_scores_greedy_search_normalized(self):
             input_ids=input_ids,
             max_new_tokens=5,
             pad_token_id=tokenizer.eos_token_id,
-            eos_token_id=None,
             return_dict_in_generate=True,
             output_scores=True,
         )
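The test change follows directly from the new precedence: passing `eos_token_id=None` to `generate()` now falls back to the model-config default instead of disabling the eos token, so the tests unset it at the source. A usage sketch of the same pattern (model checkpoint taken from the tests above):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")

# With kwargs taking precedence, `generate(..., eos_token_id=None)` no longer
# disables eos; unset the default on the generation config instead to force
# open-ended generation for a fixed number of new tokens.
model.generation_config.eos_token_id = None

inputs = tokenizer("A sample prompt", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=5,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(outputs[0]))
```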
