From c672cd7a215491cfe1469defcf76eee351e9718a Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Tue, 2 Jul 2024 13:08:07 +0000
Subject: [PATCH 1/2] rely on the tokenizer default kwargs

---
 src/transformers/pipelines/text_generation.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py
index c2dce89dd701..37d239b397d2 100644
--- a/src/transformers/pipelines/text_generation.py
+++ b/src/transformers/pipelines/text_generation.py
@@ -266,9 +266,9 @@ def preprocess(
         prompt_text,
         prefix="",
         handle_long_generation=None,
-        add_special_tokens=False,
+        add_special_tokens=None,
         truncation=None,
-        padding=False,
+        padding=None,
         max_length=None,
         **generate_kwargs,
     ):
@@ -283,14 +283,13 @@ def preprocess(
                 return_tensors=self.framework,
             )
         else:
-            inputs = self.tokenizer(
-                prefix + prompt_text,
-                truncation=truncation,
-                padding=padding,
-                max_length=max_length,
-                add_special_tokens=add_special_tokens,
-                return_tensors=self.framework,
-            )
+            # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
+            tokenizer_kwargs = {}
+            for tokenizer_kwarg_name in ["add_special_tokens", "truncation", "padding", "max_length"]:
+                if locals()[tokenizer_kwarg_name] is not None:
+                    tokenizer_kwargs[tokenizer_kwarg_name] = locals()[tokenizer_kwarg_name]
+            inputs = self.tokenizer(prefix + prompt_text, return_tensors=self.framework, **tokenizer_kwargs)
+
         inputs["prompt_text"] = prompt_text
 
         if handle_long_generation == "hole":

From cad6a4006d826c3b925841a6170318d9c4b820b6 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Tue, 2 Jul 2024 13:40:01 +0000
Subject: [PATCH 2/2] fix a few tests

---
 src/transformers/pipelines/text_generation.py     |  9 ++++++---
 tests/generation/test_utils.py                     | 13 ++++---------
 tests/pipelines/test_pipelines_text_generation.py |  2 +-
 3 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py
index 37d239b397d2..994a51748591 100644
--- a/src/transformers/pipelines/text_generation.py
+++ b/src/transformers/pipelines/text_generation.py
@@ -273,14 +273,17 @@ def preprocess(
         **generate_kwargs,
     ):
         if isinstance(prompt_text, Chat):
+            # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
+            tokenizer_kwargs = {}
+            for tokenizer_kwarg_name in ["truncation", "padding", "max_length"]:
+                if locals()[tokenizer_kwarg_name] is not None:
+                    tokenizer_kwargs[tokenizer_kwarg_name] = locals()[tokenizer_kwarg_name]
             inputs = self.tokenizer.apply_chat_template(
                 prompt_text.messages,
-                truncation=truncation,
-                padding=padding,
-                max_length=max_length,
                 add_generation_prompt=True,
                 return_dict=True,
                 return_tensors=self.framework,
+                **tokenizer_kwargs,
             )
         else:
             # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 3293cc279d01..1f775345b40c 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -2082,6 +2082,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
             [1, 18],
         )
 
+    # TODO (joao): replace `stop_sequence` in the pipeline by the more recent `generate` functionality
     def test_stop_sequence_stopping_criteria(self):
         # PT-only test: TF doesn't have StoppingCriteria
         prompt = """Hello I believe in"""
@@ -2089,17 +2090,11 @@ def test_stop_sequence_stopping_criteria(self):
         output = generator(prompt)
         self.assertEqual(
             output,
-            [
-                {
-                    "generated_text": (
-                        "Hello I believe in in in number number number number number number number number number"
-                    )
-                }
-            ],
+            [{"generated_text": ("Hello I believe in we we we we we we we we we")}],
         )
 
-        output = generator(prompt, stop_sequence=" number")
-        self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}])
+        output = generator(prompt, stop_sequence=" we")
+        self.assertEqual(output, [{"generated_text": "Hello I believe in we"}])
 
     def test_generate_non_nlp_input_ids_as_kwarg(self):
         # PT-only test: AFAIK there's no non-NLP model architecture in TF that supports `input_ids` as its only input
diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py
index 00ddd77f8260..695befe32928 100644
--- a/tests/pipelines/test_pipelines_text_generation.py
+++ b/tests/pipelines/test_pipelines_text_generation.py
@@ -398,7 +398,7 @@ def run_pipeline_test(self, text_generator, _):
             self.assertEqual(outputs, [{"generated_text": ANY(str)}])
         else:
             with self.assertRaises((ValueError, AssertionError)):
-                outputs = text_generator("")
+                outputs = text_generator("", add_special_tokens=False)
 
         if text_generator.framework == "tf":
             # TF generation does not support max_new_tokens, and it's impossible
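
Note (not part of the patch): both commits apply the same idea of only forwarding tokenizer kwargs that the caller explicitly set, so the tokenizer's (or chat template's) own defaults apply otherwise. A minimal standalone Python sketch of that filtering pattern, using a hypothetical helper name rather than the pipeline's locals()-based loop:

# Minimal sketch of the "only pass non-None kwargs" pattern used in the patch.
# build_tokenizer_kwargs is a hypothetical helper, not part of transformers.
def build_tokenizer_kwargs(add_special_tokens=None, truncation=None, padding=None, max_length=None):
    candidates = {
        "add_special_tokens": add_special_tokens,
        "truncation": truncation,
        "padding": padding,
        "max_length": max_length,
    }
    # Drop anything the caller left as None so the tokenizer's own defaults win.
    return {name: value for name, value in candidates.items() if value is not None}

print(build_tokenizer_kwargs())                            # {}
print(build_tokenizer_kwargs(padding=True, max_length=8))  # {'padding': True, 'max_length': 8}

With the patch applied, an explicit call such as text_generator("", add_special_tokens=False) (as in the updated pipeline test) still overrides the default, while calls that leave these arguments unset defer entirely to the tokenizer.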