Enable TruncationStrategy override for pipelines (huggingface#9432)
* Enable TruncationStrategy override for pipelines

* Update isort.

* Fixing test

* Fixing text_generation pipeline.

* Using same DummyTok as other PR for easier merge later.

* Some more import guards.

* Remove bogus file.

* Do not pass `generate_kwargs` to `_parse_and_tokenize`.
@patrickvonplaten

* Removed DummyTok.

* Doc quality.
Narsil authored Jan 11, 2021
1 parent 8d25df2 commit d20e9c7
Showing 6 changed files with 95 additions and 26 deletions.
7 changes: 5 additions & 2 deletions src/transformers/pipelines/base.py
@@ -24,7 +24,7 @@
 
 from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
 from ..modelcard import ModelCard
-from ..tokenization_utils import PreTrainedTokenizer
+from ..tokenization_utils import PreTrainedTokenizer, TruncationStrategy
 from ..utils import logging
 
@@ -577,7 +577,9 @@ def check_model_type(self, supported_models: Union[List[str], dict]):
                 f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}",
             )
 
-    def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
+    def _parse_and_tokenize(
+        self, inputs, padding=True, add_special_tokens=True, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs
+    ):
         """
         Parse arguments and tokenize
         """
@@ -587,6 +589,7 @@ def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
             add_special_tokens=add_special_tokens,
             return_tensors=self.framework,
             padding=padding,
+            truncation=truncation,
         )
 
         return inputs
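
The base pipeline now simply forwards `truncation` to the tokenizer, defaulting to no truncation so existing callers are unaffected. For reference, `TruncationStrategy` is the tokenizer's existing enum; its members mirror the strings the tokenizer already accepts. A minimal sketch:

from transformers.tokenization_utils import TruncationStrategy

# Enum members and their string forms are interchangeable when calling the tokenizer.
assert TruncationStrategy.DO_NOT_TRUNCATE.value == "do_not_truncate"
assert TruncationStrategy.ONLY_FIRST.value == "only_first"
assert TruncationStrategy.LONGEST_FIRST.value == "longest_first"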
7 changes: 5 additions & 2 deletions src/transformers/pipelines/conversational.py
@@ -2,6 +2,7 @@
 from typing import List, Optional, Union
 
 from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
+from ..tokenization_utils import TruncationStrategy
 from ..utils import logging
 from .base import PIPELINE_INIT_ARGS, Pipeline
 
@@ -317,12 +318,14 @@ def __call__(
         else:
             return output
 
-    def _parse_and_tokenize(self, inputs, **kwargs):
+    def _parse_and_tokenize(
+        self, inputs, add_special_tokens=False, padding=False, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs
+    ):
         """
         Parse arguments and tokenize, adding an EOS token at the end of the user input
         """
         # Parse arguments
-        inputs = self.tokenizer(inputs, add_special_tokens=False, padding=False).get("input_ids", [])
+        inputs = self.tokenizer(inputs, add_special_tokens=add_special_tokens, padding=padding).get("input_ids", [])
         for input in inputs:
             input.append(self.tokenizer.eos_token_id)
         return inputs
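
Note that this pipeline still tokenizes to plain id lists (no tensors) and appends the EOS token per turn; the new signature only surfaces the previous hard-coded defaults as arguments. For context, a typical call into this pipeline (the model name is illustrative, not part of this commit):

from transformers import Conversation, pipeline

chatbot = pipeline("conversational", model="microsoft/DialoGPT-small")
# _parse_and_tokenize runs with add_special_tokens=False and padding=False,
# then appends the tokenizer's EOS id to each user input.
result = chatbot(Conversation("Going to the movies tonight, any suggestions?"))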
18 changes: 13 additions & 5 deletions src/transformers/pipelines/text2text_generation.py
@@ -1,4 +1,5 @@
 from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
+from ..tokenization_utils import TruncationStrategy
 from ..utils import logging
 from .base import PIPELINE_INIT_ARGS, Pipeline
 
@@ -50,7 +51,13 @@ def check_inputs(self, input_length: int, min_length: int, max_length: int):
         return True
 
     def __call__(
-        self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
+        self,
+        *args,
+        return_tensors=False,
+        return_text=True,
+        clean_up_tokenization_spaces=False,
+        truncation=TruncationStrategy.DO_NOT_TRUNCATE,
+        **generate_kwargs
     ):
         r"""
         Generate the output text(s) using text(s) given as inputs.
@@ -64,6 +71,10 @@ def __call__(
                 Whether or not to include the decoded texts in the outputs.
             clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether or not to clean up the potential extra spaces in the text output.
+            truncation (:obj:`TruncationStrategy`, `optional`, defaults to :obj:`TruncationStrategy.DO_NOT_TRUNCATE`):
+                The truncation strategy for the tokenization within the pipeline.
+                :obj:`TruncationStrategy.DO_NOT_TRUNCATE` (default) will never truncate, but it is sometimes desirable
+                to truncate the input to fit the model's max_length instead of throwing an error down the line.
             generate_kwargs:
                 Additional keyword arguments to pass along to the generate method of the model (see the generate method
                 corresponding to your framework `here <./model.html#generative-models>`__).
@@ -96,7 +107,7 @@ def __call__(
             )
 
         with self.device_placement():
-            inputs = self._parse_and_tokenize(*args, padding=padding, **generate_kwargs)
+            inputs = self._parse_and_tokenize(*args, padding=padding, truncation=truncation)
 
             if self.framework == "pt":
                 inputs = self.ensure_tensor_on_device(**inputs)
@@ -108,9 +119,6 @@ def __call__(
             max_length = generate_kwargs.get("max_length", self.model.config.max_length)
             self.check_inputs(input_length, min_length, max_length)
 
-            # truncation should be used by _parse_and_tokenize
-            generate_kwargs.pop("truncation", None)
-
            generations = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
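
This is the user-visible core of the change: text2text pipelines (summarization, translation) now take `truncation` as an explicit keyword instead of letting over-long inputs fail later inside generate. A hedged usage sketch (the model name is illustrative, not part of this commit):

from transformers import pipeline
from transformers.tokenization_utils import TruncationStrategy

summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
very_long_text = "The quick brown fox jumps over the lazy dog. " * 2000

# Without truncation this input would exceed the model's max length;
# ONLY_FIRST truncates the (single) input sequence so it fits.
summary = summarizer(very_long_text, truncation=TruncationStrategy.ONLY_FIRST)
print(summary[0]["summary_text"])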
18 changes: 4 additions & 14 deletions src/transformers/pipelines/text_generation.py
@@ -50,25 +50,15 @@ def __init__(self, *args, **kwargs):
         self.check_model_type(self.ALLOWED_MODELS)
 
     # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
-
-    def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
+    def _parse_and_tokenize(self, *args, **kwargs):
         """
         Parse arguments and tokenize
         """
         # Parse arguments
         if self.model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
-            tokenizer_kwargs = {"add_space_before_punct_symbol": True}
-        else:
-            tokenizer_kwargs = {}
-        inputs = self.tokenizer(
-            inputs,
-            add_special_tokens=add_special_tokens,
-            return_tensors=self.framework,
-            padding=padding,
-            **tokenizer_kwargs,
-        )
-
-        return inputs
+            kwargs.update({"add_space_before_punct_symbol": True})
+
+        return super()._parse_and_tokenize(*args, **kwargs)
 
     def __call__(
         self,
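
The rewrite drops a duplicated copy of the base tokenization code in favor of adjusting kwargs and delegating, so new base arguments such as `truncation` flow through without further edits here. A hypothetical minimal illustration of the same override pattern (not code from this commit):

from transformers import Pipeline

class MyPipeline(Pipeline):
    def _parse_and_tokenize(self, *args, **kwargs):
        # Tweak tokenizer arguments, then defer to the base implementation
        # so any future keyword argument keeps working without edits here.
        kwargs.setdefault("padding", False)
        return super()._parse_and_tokenize(*args, **kwargs)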
12 changes: 10 additions & 2 deletions src/transformers/pipelines/zero_shot_classification.py
@@ -3,6 +3,7 @@
 import numpy as np
 
 from ..file_utils import add_end_docstrings
+from ..tokenization_utils import TruncationStrategy
 from ..utils import logging
 from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline
 
@@ -78,7 +79,14 @@ def entailment_id(self):
             return -1
 
     def _parse_and_tokenize(
-        self, sequences, candidate_labels, hypothesis_template, padding=True, add_special_tokens=True, **kwargs
+        self,
+        sequences,
+        candidate_labels,
+        hypothesis_template,
+        padding=True,
+        add_special_tokens=True,
+        truncation=TruncationStrategy.ONLY_FIRST,
+        **kwargs
     ):
         """
         Parse arguments and tokenize only_first so that hypothesis (label) is not truncated
@@ -89,7 +97,7 @@ def _parse_and_tokenize(
             add_special_tokens=add_special_tokens,
             return_tensors=self.framework,
             padding=padding,
-            truncation=truncation,
+            truncation=truncation,
        )
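
ONLY_FIRST is the right default here because this pipeline encodes (sequence, hypothesis) pairs: truncation may shorten the sequence but must never eat the hypothesis built from each candidate label via the hypothesis template. Illustrative usage (the model shown is the usual zero-shot default, assumed here):

from transformers import pipeline

classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
long_premise = "The new graphics card delivers excellent performance. " * 200
# The premise (first segment) is truncated to fit the model;
# the label hypothesis (second segment) is preserved intact.
result = classifier(long_premise, candidate_labels=["technology", "cooking"])
print(result["labels"][0])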
59 changes: 58 additions & 1 deletion tests/test_pipelines_summarization.py
@@ -14,15 +14,72 @@
 
 import unittest
 
-from transformers import pipeline
+from transformers import AutoTokenizer, is_torch_available, pipeline
 from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.tokenization_utils import TruncationStrategy
 
 from .test_pipelines_common import MonoInputPipelineCommonMixin
 
 
+if is_torch_available():
+    import torch
+
+    from transformers.models.bart import BartConfig, BartForConditionalGeneration
+
 DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0
 
 
+class SimpleSummarizationPipelineTests(unittest.TestCase):
+    @require_torch
+    def test_input_too_long(self):
+        torch.manual_seed(0)
+        config = BartConfig(
+            vocab_size=257,
+            d_model=32,
+            encoder_layers=1,
+            decoder_layers=1,
+            encoder_ffn_dim=32,
+            decoder_ffn_dim=32,
+            # So any text > 4 should raise an exception
+            max_position_embeddings=4,
+            encoder_attention_heads=1,
+            decoder_attention_heads=1,
+            max_length=4,
+            min_length=1,
+        )
+        model = BartForConditionalGeneration(config)
+        # Bias output towards L
+        V, C = model.lm_head.weight.shape
+
+        bias = torch.zeros(V, requires_grad=True)
+        bias[76] = 10
+
+        model.lm_head.bias = torch.nn.Parameter(bias)
+
+        # # Generated with:
+        # import tempfile
+        # from tokenizers import Tokenizer, models
+        # from transformers import PreTrainedTokenizerFast
+        # model_max_length = 4
+        # vocab = [(chr(i), i) for i in range(256)]
+        # tokenizer = Tokenizer(models.Unigram(vocab))
+        # with tempfile.NamedTemporaryFile() as f:
+        #     tokenizer.save(f.name)
+        #     real_tokenizer = PreTrainedTokenizerFast(tokenizer_file=f.name, model_max_length=model_max_length)
+        # real_tokenizer._tokenizer.save("tokenizer.json")
+        # # + add missing config.json with albert as model_type
+        tokenizer = AutoTokenizer.from_pretrained("Narsil/small_summarization_test")
+        nlp = pipeline(task="summarization", model=model, tokenizer=tokenizer)
+
+        with self.assertLogs("transformers", level="WARNING"):
+            with self.assertRaises(IndexError):
+                _ = nlp("This is a test")
+
+        output = nlp("This is a test", truncation=TruncationStrategy.ONLY_FIRST)
+        # 2 is default BOS from Bart.
+        self.assertEqual(output, [{"summary_text": "\x02 L L L"}])
+
+
 class SummarizationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
     pipeline_task = "summarization"
     pipeline_running_kwargs = {"num_beams": 2, "min_length": 2, "max_length": 5}
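
Two details of the assertion above: the test tokenizer's vocabulary maps each id i to chr(i), so the biased logit at index 76 yields "L" (chr(76)) while Bart's default BOS id 2 decodes to "\x02". And since the pipeline forwards `truncation` verbatim to the tokenizer, the string form is assumed to behave identically to the enum form used in the test:

output = nlp("This is a test", truncation="only_first")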
