From 7c00bb18126212d7ff835cfc8979917204d93fd8 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 26 Mar 2024 15:48:50 +0100 Subject: [PATCH] update PR --- src/transformers/generation/utils.py | 39 ++++++++++++++++------------ 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 924844ddea76..a958c8c86a92 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1915,8 +1915,8 @@ def _contrastive_search( if eos_token_id is not None: logger.warning_once( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.", - " Otherwise make sure to set `model.generation_config.eos_token_id`" + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." + " Otherwise make sure to set `model.generation_config.eos_token_id`", FutureWarning, ) stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) @@ -2373,9 +2373,10 @@ def _greedy_search( stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id if eos_token_id is not None: - warnings.warn( + logger.warning_once( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.", + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." + " Otherwise make sure to set `model.generation_config.eos_token_id`", FutureWarning, ) stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) @@ -2668,9 +2669,10 @@ def _sample( logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id if eos_token_id is not None: - warnings.warn( + logger.warning_once( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.", + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." + " Otherwise make sure to set `model.generation_config.eos_token_id`", FutureWarning, ) stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) @@ -2994,9 +2996,10 @@ def _beam_search( warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id if eos_token_id is not None: - warnings.warn( + logger.warning_once( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.", + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." + " Otherwise make sure to set `model.generation_config.eos_token_id`", FutureWarning, ) stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) @@ -3396,9 +3399,10 @@ def _beam_sample( stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id if eos_token_id is not None: - warnings.warn( + logger.warning_once( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.", + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." + " Otherwise make sure to set `model.generation_config.eos_token_id`", FutureWarning, ) stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) @@ -3750,9 +3754,10 @@ def _group_beam_search( stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id if eos_token_id is not None: - warnings.warn( + logger.warning_once( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.", + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." + " Otherwise make sure to set `model.generation_config.eos_token_id`", FutureWarning, ) stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) @@ -4168,9 +4173,10 @@ def _constrained_beam_search( warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id if eos_token_id is not None: - warnings.warn( + logger.warning_once( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.", + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." + " Otherwise make sure to set `model.generation_config.eos_token_id`", FutureWarning, ) stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) @@ -4517,9 +4523,10 @@ def _assisted_decoding( stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id if eos_token_id is not None: - warnings.warn( + logger.warning_once( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.", + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." + " Otherwise make sure to set `model.generation_config.eos_token_id`", FutureWarning, ) stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))