Skip to content

Commit

Permalink
update PR
Browse files Browse the repository at this point in the history
  • Loading branch information
zucchini-nlp committed Mar 26, 2024
1 parent a385c6d commit 7c00bb1
Showing 1 changed file with 23 additions and 16 deletions.
39 changes: 23 additions & 16 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1915,8 +1915,8 @@ def _contrastive_search(
if eos_token_id is not None:
logger.warning_once(
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.",
" Otherwise make sure to set `model.generation_config.eos_token_id`"
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
" Otherwise make sure to set `model.generation_config.eos_token_id`",
FutureWarning,
)
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
Expand Down Expand Up @@ -2373,9 +2373,10 @@ def _greedy_search(
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
if eos_token_id is not None:
warnings.warn(
logger.warning_once(
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.",
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
" Otherwise make sure to set `model.generation_config.eos_token_id`",
FutureWarning,
)
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
Expand Down Expand Up @@ -2668,9 +2669,10 @@ def _sample(
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
if eos_token_id is not None:
warnings.warn(
logger.warning_once(
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.",
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
" Otherwise make sure to set `model.generation_config.eos_token_id`",
FutureWarning,
)
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
Expand Down Expand Up @@ -2994,9 +2996,10 @@ def _beam_search(
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
if eos_token_id is not None:
warnings.warn(
logger.warning_once(
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.",
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
" Otherwise make sure to set `model.generation_config.eos_token_id`",
FutureWarning,
)
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
Expand Down Expand Up @@ -3396,9 +3399,10 @@ def _beam_sample(
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
if eos_token_id is not None:
warnings.warn(
logger.warning_once(
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.",
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
" Otherwise make sure to set `model.generation_config.eos_token_id`",
FutureWarning,
)
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
Expand Down Expand Up @@ -3750,9 +3754,10 @@ def _group_beam_search(
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
if eos_token_id is not None:
warnings.warn(
logger.warning_once(
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.",
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
" Otherwise make sure to set `model.generation_config.eos_token_id`",
FutureWarning,
)
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
Expand Down Expand Up @@ -4168,9 +4173,10 @@ def _constrained_beam_search(
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
if eos_token_id is not None:
warnings.warn(
logger.warning_once(
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.",
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
" Otherwise make sure to set `model.generation_config.eos_token_id`",
FutureWarning,
)
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
Expand Down Expand Up @@ -4517,9 +4523,10 @@ def _assisted_decoding(
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
if eos_token_id is not None:
warnings.warn(
logger.warning_once(
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.",
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
" Otherwise make sure to set `model.generation_config.eos_token_id`",
FutureWarning,
)
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
Expand Down

0 comments on commit 7c00bb1

Please sign in to comment.