diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py
index c2dce89dd701..994a51748591 100644
--- a/src/transformers/pipelines/text_generation.py
+++ b/src/transformers/pipelines/text_generation.py
@@ -266,31 +266,33 @@ def preprocess(
         prompt_text,
         prefix="",
         handle_long_generation=None,
-        add_special_tokens=False,
+        add_special_tokens=None,
         truncation=None,
-        padding=False,
+        padding=None,
         max_length=None,
         **generate_kwargs,
     ):
         if isinstance(prompt_text, Chat):
+            # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
+            tokenizer_kwargs = {}
+            for tokenizer_kwarg_name in ["truncation", "padding", "max_length"]:
+                if locals()[tokenizer_kwarg_name] is not None:
+                    tokenizer_kwargs[tokenizer_kwarg_name] = locals()[tokenizer_kwarg_name]
             inputs = self.tokenizer.apply_chat_template(
                 prompt_text.messages,
-                truncation=truncation,
-                padding=padding,
-                max_length=max_length,
                 add_generation_prompt=True,
                 return_dict=True,
                 return_tensors=self.framework,
+                **tokenizer_kwargs,
             )
         else:
-            inputs = self.tokenizer(
-                prefix + prompt_text,
-                truncation=truncation,
-                padding=padding,
-                max_length=max_length,
-                add_special_tokens=add_special_tokens,
-                return_tensors=self.framework,
-            )
+            # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
+            tokenizer_kwargs = {}
+            for tokenizer_kwarg_name in ["add_special_tokens", "truncation", "padding", "max_length"]:
+                if locals()[tokenizer_kwarg_name] is not None:
+                    tokenizer_kwargs[tokenizer_kwarg_name] = locals()[tokenizer_kwarg_name]
+            inputs = self.tokenizer(prefix + prompt_text, return_tensors=self.framework, **tokenizer_kwargs)
+
         inputs["prompt_text"] = prompt_text
 
         if handle_long_generation == "hole":
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 3293cc279d01..1f775345b40c 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -2082,6 +2082,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
             [1, 18],
         )
 
+    # TODO (joao): replace `stop_sequence` in the pipeline by the more recent `generate` functionality
     def test_stop_sequence_stopping_criteria(self):
         # PT-only test: TF doesn't have StoppingCriteria
         prompt = """Hello I believe in"""
@@ -2089,17 +2090,11 @@ def test_stop_sequence_stopping_criteria(self):
         output = generator(prompt)
         self.assertEqual(
             output,
-            [
-                {
-                    "generated_text": (
-                        "Hello I believe in in in number number number number number number number number number"
-                    )
-                }
-            ],
+            [{"generated_text": ("Hello I believe in we we we we we we we we we")}],
         )
 
-        output = generator(prompt, stop_sequence=" number")
-        self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}])
+        output = generator(prompt, stop_sequence=" we")
+        self.assertEqual(output, [{"generated_text": "Hello I believe in we"}])
 
     def test_generate_non_nlp_input_ids_as_kwarg(self):
         # PT-only test: AFAIK there's no non-NLP model architecture in TF that supports `input_ids` as its only input
diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py
index 00ddd77f8260..695befe32928 100644
--- a/tests/pipelines/test_pipelines_text_generation.py
+++ b/tests/pipelines/test_pipelines_text_generation.py
@@ -398,7 +398,7 @@ def run_pipeline_test(self, text_generator, _):
             self.assertEqual(outputs, [{"generated_text": ANY(str)}])
         else:
             with self.assertRaises((ValueError, AssertionError)):
-                outputs = text_generator("")
+                outputs = text_generator("", add_special_tokens=False)
 
         if text_generator.framework == "tf":
             # TF generation does not support max_new_tokens, and it's impossible