Allow FP16 or other precision inference for Pipelines (#31342)
* cast image features to model.dtype where needed to support FP16 or other precision in pipelines

* Update src/transformers/pipelines/image_feature_extraction.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Use .to instead

* Add FP16 pipeline support for zeroshot audio classification

* Remove unused torch imports

* Add docs on FP16 pipeline

* Remove unused import

* Add FP16 tests to pipeline mixin

* Add fp16 placeholder for mask_generation pipeline test

* Add FP16 tests for all pipelines

* Fix formatting

* Remove torch_dtype arg from is_pipeline_test_to_skip*

* Fix format

* trigger ci

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
aliencaocao and amyeroberts authored Jul 5, 2024
1 parent e786844 commit ac26260
Showing 45 changed files with 354 additions and 79 deletions.
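
The diff repeats one core pattern: when the backend is PyTorch, each pipeline's `preprocess` casts the processor output to the pipeline's `torch_dtype` before the forward pass. A standalone sketch of the idea (the helper below is illustrative, not part of the commit; the library's `BatchFeature.to(dtype)` similarly casts only floating-point tensors):

```python
import torch

# Illustrative sketch of the commit's recurring pattern: cast preprocessed
# inputs to the pipeline's torch_dtype so they match FP16/BF16 weights.
def cast_inputs(model_inputs: dict, framework: str, torch_dtype: torch.dtype) -> dict:
    if framework != "pt":
        return model_inputs  # non-PyTorch frameworks are left untouched
    return {
        # Cast only floating-point tensors; integer tensors (e.g. input_ids)
        # must keep their dtype.
        k: v.to(torch_dtype) if torch.is_tensor(v) and v.is_floating_point() else v
        for k, v in model_inputs.items()
    }

inputs = {
    "pixel_values": torch.rand(1, 3, 224, 224),  # float32 image features
    "input_ids": torch.tensor([[101, 102]]),     # int64 token ids
}
casted = cast_inputs(inputs, framework="pt", torch_dtype=torch.float16)
print(casted["pixel_values"].dtype, casted["input_ids"].dtype)
# torch.float16 torch.int64
```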
5 changes: 5 additions & 0 deletions docs/source/en/main_classes/pipelines.md
@@ -270,6 +270,11 @@ This is a simplified view, since the pipeline can handle automatically the batch
about how many forward passes your inputs are actually going to trigger, you can optimize the `batch_size`
independently of the inputs. The caveats from the previous section still apply.

## Pipeline FP16 inference
Models can be run in FP16, which can be significantly faster on GPU while saving memory. Most models will not suffer a noticeable loss in accuracy; the larger the model, the less likely it is to be affected.

To enable FP16 inference, you can simply pass `torch_dtype=torch.float16` or `torch_dtype='float16'` to the pipeline constructor. Note that this only works for models with a PyTorch backend. Your inputs will be converted to FP16 internally.
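
For example, a minimal sketch (the checkpoint and image URL are illustrative):

```python
import torch
from transformers import pipeline

# Any PyTorch-backed pipeline accepts torch_dtype; checkpoint is illustrative.
classifier = pipeline(
    "image-classification",
    model="google/vit-base-patch16-224",
    device=0,                   # FP16 pays off most on GPU
    torch_dtype=torch.float16,  # or torch_dtype="float16"
)
print(classifier("http://images.cocodataset.org/val2017/000000039769.jpg"))
```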

## Pipeline custom code

If you want to override a specific pipeline.
5 changes: 3 additions & 2 deletions docs/source/en/pipeline_tutorial.md
@@ -113,7 +113,9 @@ This will work regardless of whether you are using PyTorch or TensorFlow.
transcriber = pipeline(model="openai/whisper-large-v2", device=0)
```

If the model is too large for a single GPU and you are using PyTorch, you can set `device_map="auto"` to automatically
If the model is too large for a single GPU and you are using PyTorch, you can set `torch_dtype='float16'` to enable FP16 precision inference. Usually this does not cause significant quality drops, but make sure to evaluate it on your own models!
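
For instance, building on the transcriber example above (a minimal sketch):

```python
import torch
from transformers import pipeline

# Same transcriber as above, now with FP16 weights; pipeline inputs are
# cast to FP16 internally.
transcriber = pipeline(model="openai/whisper-large-v2", device=0, torch_dtype=torch.float16)
```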

Alternatively, you can set `device_map="auto"` to automatically
determine how to load and store the model weights. Using the `device_map` argument requires the 🤗 [Accelerate](https://huggingface.co/docs/accelerate)
package:

@@ -342,4 +344,3 @@ gr.Interface.from_pipeline(pipe).launch()

By default, the web demo runs on a local server. If you'd like to share it with others, you can generate a temporary public
link by setting `share=True` in `launch()`. You can also host your demo on [Hugging Face Spaces](https://huggingface.co/spaces) for a permanent link.

2 changes: 2 additions & 0 deletions src/transformers/pipelines/depth_estimation.py
@@ -91,6 +91,8 @@ def preprocess(self, image, timeout=None):
image = load_image(image, timeout)
self.image_size = image.size
model_inputs = self.image_processor(images=image, return_tensors=self.framework)
if self.framework == "pt":
model_inputs = model_inputs.to(self.torch_dtype)
return model_inputs

def _forward(self, model_inputs):
5 changes: 4 additions & 1 deletion src/transformers/pipelines/document_question_answering.py
@@ -294,7 +294,10 @@ def preprocess(
if input.get("image", None) is not None:
image = load_image(input["image"], timeout=timeout)
if self.image_processor is not None:
image_features.update(self.image_processor(images=image, return_tensors=self.framework))
image_inputs = self.image_processor(images=image, return_tensors=self.framework)
if self.framework == "pt":
image_inputs = image_inputs.to(self.torch_dtype)
image_features.update(image_inputs)
elif self.feature_extractor is not None:
image_features.update(self.feature_extractor(images=image, return_tensors=self.framework))
elif self.model_type == ModelType.VisionEncoderDecoder:
2 changes: 2 additions & 0 deletions src/transformers/pipelines/image_classification.py
@@ -161,6 +161,8 @@ def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Imag
def preprocess(self, image, timeout=None):
image = load_image(image, timeout=timeout)
model_inputs = self.image_processor(images=image, return_tensors=self.framework)
if self.framework == "pt":
model_inputs = model_inputs.to(self.torch_dtype)
return model_inputs

def _forward(self, model_inputs):
2 changes: 2 additions & 0 deletions src/transformers/pipelines/image_feature_extraction.py
@@ -60,6 +60,8 @@ def _sanitize_parameters(self, image_processor_kwargs=None, return_tensors=None,
def preprocess(self, image, timeout=None, **image_processor_kwargs) -> Dict[str, GenericTensor]:
image = load_image(image, timeout=timeout)
model_inputs = self.image_processor(image, return_tensors=self.framework, **image_processor_kwargs)
if self.framework == "pt":
model_inputs = model_inputs.to(self.torch_dtype)
return model_inputs

def _forward(self, model_inputs):
4 changes: 4 additions & 0 deletions src/transformers/pipelines/image_segmentation.py
@@ -147,6 +147,8 @@ def preprocess(self, image, subtask=None, timeout=None):
else:
kwargs = {"task_inputs": [subtask]}
inputs = self.image_processor(images=[image], return_tensors="pt", **kwargs)
if self.framework == "pt":
inputs = inputs.to(self.torch_dtype)
inputs["task_inputs"] = self.tokenizer(
inputs["task_inputs"],
padding="max_length",
@@ -155,6 +157,8 @@ def preprocess(self, image, subtask=None, timeout=None):
)["input_ids"]
else:
inputs = self.image_processor(images=[image], return_tensors="pt")
if self.framework == "pt":
inputs = inputs.to(self.torch_dtype)
inputs["target_size"] = target_size
return inputs

2 changes: 2 additions & 0 deletions src/transformers/pipelines/image_to_image.py
@@ -119,6 +119,8 @@ def _forward(self, model_inputs):
def preprocess(self, image, timeout=None):
image = load_image(image, timeout=timeout)
inputs = self.image_processor(images=[image], return_tensors="pt")
if self.framework == "pt":
inputs = inputs.to(self.torch_dtype)
return inputs

def postprocess(self, model_outputs):
8 changes: 8 additions & 0 deletions src/transformers/pipelines/image_to_text.py
@@ -138,17 +138,23 @@ def preprocess(self, image, prompt=None, timeout=None):

if model_type == "git":
model_inputs = self.image_processor(images=image, return_tensors=self.framework)
if self.framework == "pt":
model_inputs = model_inputs.to(self.torch_dtype)
input_ids = self.tokenizer(text=prompt, add_special_tokens=False).input_ids
input_ids = [self.tokenizer.cls_token_id] + input_ids
input_ids = torch.tensor(input_ids).unsqueeze(0)
model_inputs.update({"input_ids": input_ids})

elif model_type == "pix2struct":
model_inputs = self.image_processor(images=image, header_text=prompt, return_tensors=self.framework)
if self.framework == "pt":
model_inputs = model_inputs.to(self.torch_dtype)

elif model_type != "vision-encoder-decoder":
# vision-encoder-decoder does not support conditional generation
model_inputs = self.image_processor(images=image, return_tensors=self.framework)
if self.framework == "pt":
model_inputs = model_inputs.to(self.torch_dtype)
text_inputs = self.tokenizer(prompt, return_tensors=self.framework)
model_inputs.update(text_inputs)

@@ -157,6 +163,8 @@ def preprocess(self, image, prompt=None, timeout=None):

else:
model_inputs = self.image_processor(images=image, return_tensors=self.framework)
if self.framework == "pt":
model_inputs = model_inputs.to(self.torch_dtype)

if self.model.config.model_type == "git" and prompt is None:
model_inputs["input_ids"] = None
2 changes: 2 additions & 0 deletions src/transformers/pipelines/mask_generation.py
@@ -181,6 +181,8 @@ def preprocess(
image, target_size, crops_n_layers, crop_overlap_ratio, points_per_crop, crop_n_points_downscale_factor
)
model_inputs = self.image_processor(images=cropped_images, return_tensors="pt")
if self.framework == "pt":
model_inputs = model_inputs.to(self.torch_dtype)

with self.device_placement():
if self.framework == "pt":
2 changes: 2 additions & 0 deletions src/transformers/pipelines/object_detection.py
@@ -107,6 +107,8 @@ def preprocess(self, image, timeout=None):
image = load_image(image, timeout=timeout)
target_size = torch.IntTensor([[image.height, image.width]])
inputs = self.image_processor(images=[image], return_tensors="pt")
if self.framework == "pt":
inputs = inputs.to(self.torch_dtype)
if self.tokenizer is not None:
inputs = self.tokenizer(text=inputs["words"], boxes=inputs["boxes"], return_tensors="pt")
inputs["target_size"] = target_size
2 changes: 2 additions & 0 deletions src/transformers/pipelines/video_classification.py
@@ -106,6 +106,8 @@ def preprocess(self, video, num_frames=None, frame_sampling_rate=1):
video = list(video)

model_inputs = self.image_processor(video, return_tensors=self.framework)
if self.framework == "pt":
model_inputs = model_inputs.to(self.torch_dtype)
return model_inputs

def _forward(self, model_inputs):
2 changes: 2 additions & 0 deletions src/transformers/pipelines/visual_question_answering.py
@@ -155,6 +155,8 @@ def preprocess(self, inputs, padding=False, truncation=False, timeout=None):
truncation=truncation,
)
image_features = self.image_processor(images=image, return_tensors=self.framework)
if self.framework == "pt":
image_features = image_features.to(self.torch_dtype)
model_inputs.update(image_features)
return model_inputs

2 changes: 2 additions & 0 deletions src/transformers/pipelines/zero_shot_audio_classification.py
@@ -121,6 +121,8 @@ def preprocess(self, audio, candidate_labels=None, hypothesis_template="This is
inputs = self.feature_extractor(
[audio], sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
)
if self.framework == "pt":
inputs = inputs.to(self.torch_dtype)
inputs["candidate_labels"] = candidate_labels
sequences = [hypothesis_template.format(x) for x in candidate_labels]
text_inputs = self.tokenizer(sequences, return_tensors=self.framework, padding=True)
2 changes: 2 additions & 0 deletions src/transformers/pipelines/zero_shot_image_classification.py
@@ -120,6 +120,8 @@ def _sanitize_parameters(self, **kwargs):
def preprocess(self, image, candidate_labels=None, hypothesis_template="This is a photo of {}.", timeout=None):
image = load_image(image, timeout=timeout)
inputs = self.image_processor(images=[image], return_tensors=self.framework)
if self.framework == "pt":
inputs = inputs.to(self.torch_dtype)
inputs["candidate_labels"] = candidate_labels
sequences = [hypothesis_template.format(x) for x in candidate_labels]
padding = "max_length" if self.model.config.model_type == "siglip" else True
2 changes: 2 additions & 0 deletions src/transformers/pipelines/zero_shot_object_detection.py
@@ -156,6 +156,8 @@ def preprocess(self, inputs, timeout=None):
for i, candidate_label in enumerate(candidate_labels):
text_inputs = self.tokenizer(candidate_label, return_tensors=self.framework)
image_features = self.image_processor(image, return_tensors=self.framework)
if self.framework == "pt":
image_features = image_features.to(self.torch_dtype)
yield {
"is_last": i == len(candidate_labels) - 1,
"target_size": target_size,
6 changes: 4 additions & 2 deletions tests/pipelines/test_pipelines_audio_classification.py
@@ -35,8 +35,10 @@ class AudioClassificationPipelineTests(unittest.TestCase):
model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
tf_model_mapping = TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING

def get_test_pipeline(self, model, tokenizer, processor):
audio_classifier = AudioClassificationPipeline(model=model, feature_extractor=processor)
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
audio_classifier = AudioClassificationPipeline(
model=model, feature_extractor=processor, torch_dtype=torch_dtype
)

# test with a raw waveform
audio = np.zeros((34000,))
tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -66,14 +66,14 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
+ (MODEL_FOR_CTC_MAPPING.items() if MODEL_FOR_CTC_MAPPING else [])
)

def get_test_pipeline(self, model, tokenizer, processor):
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
if tokenizer is None:
# Side effect of no Fast Tokenizer class for these models, so skipping
# But the slow tokenizer test should still run as they're quite small
self.skipTest(reason="No tokenizer available")

speech_recognizer = AutomaticSpeechRecognitionPipeline(
model=model, tokenizer=tokenizer, feature_extractor=processor
model=model, tokenizer=tokenizer, feature_extractor=processor, torch_dtype=torch_dtype
)

# test with a raw waveform
4 changes: 2 additions & 2 deletions tests/pipelines/test_pipelines_depth_estimation.py
@@ -56,8 +56,8 @@ def hashimage(image: Image) -> str:
class DepthEstimationPipelineTests(unittest.TestCase):
model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING

def get_test_pipeline(self, model, tokenizer, processor):
depth_estimator = DepthEstimationPipeline(model=model, image_processor=processor)
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
depth_estimator = DepthEstimationPipeline(model=model, image_processor=processor, torch_dtype=torch_dtype)
return depth_estimator, [
"./tests/fixtures/tests_samples/COCO/000000039769.png",
"./tests/fixtures/tests_samples/COCO/000000039769.png",
8 changes: 6 additions & 2 deletions tests/pipelines/test_pipelines_document_question_answering.py
@@ -61,9 +61,13 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase):

@require_pytesseract
@require_vision
def get_test_pipeline(self, model, tokenizer, processor):
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
dqa_pipeline = pipeline(
"document-question-answering", model=model, tokenizer=tokenizer, image_processor=processor
"document-question-answering",
model=model,
tokenizer=tokenizer,
image_processor=processor,
torch_dtype=torch_dtype,
)

image = INVOICE_URL
6 changes: 4 additions & 2 deletions tests/pipelines/test_pipelines_feature_extraction.py
@@ -174,7 +174,7 @@ def get_shape(self, input_, shape=None):
raise ValueError("We expect lists of floats, nothing else")
return shape

def get_test_pipeline(self, model, tokenizer, processor):
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
if tokenizer is None:
self.skipTest(reason="No tokenizer")
elif (
@@ -193,7 +193,9 @@ def get_test_pipeline(self, model, tokenizer, processor):
For now ignore those.
"""
)
feature_extractor = FeatureExtractionPipeline(model=model, tokenizer=tokenizer, feature_extractor=processor)
feature_extractor = FeatureExtractionPipeline(
model=model, tokenizer=tokenizer, feature_extractor=processor, torch_dtype=torch_dtype
)
return feature_extractor, ["This is a test", "This is another test"]

def run_pipeline_test(self, feature_extractor, examples):
4 changes: 2 additions & 2 deletions tests/pipelines/test_pipelines_fill_mask.py
@@ -251,11 +251,11 @@ def test_model_no_pad_tf(self):
unmasker.tokenizer.pad_token = None
self.run_pipeline_test(unmasker, [])

def get_test_pipeline(self, model, tokenizer, processor):
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
if tokenizer is None or tokenizer.mask_token_id is None:
self.skipTest(reason="The provided tokenizer has no mask token, (probably reformer or wav2vec2)")

fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer, torch_dtype=torch_dtype)
examples = [
f"This is another {tokenizer.mask_token} test",
]
6 changes: 4 additions & 2 deletions tests/pipelines/test_pipelines_image_classification.py
@@ -55,8 +55,10 @@ class ImageClassificationPipelineTests(unittest.TestCase):
model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
tf_model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING

def get_test_pipeline(self, model, tokenizer, processor):
image_classifier = ImageClassificationPipeline(model=model, image_processor=processor, top_k=2)
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
image_classifier = ImageClassificationPipeline(
model=model, image_processor=processor, top_k=2, torch_dtype=torch_dtype
)
examples = [
Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
"http://images.cocodataset.org/val2017/000000039769.jpg",
6 changes: 4 additions & 2 deletions tests/pipelines/test_pipelines_image_feature_extraction.py
@@ -157,7 +157,7 @@ def test_return_tensors_tf(self):
outputs = feature_extractor(img, return_tensors=True)
self.assertTrue(tf.is_tensor(outputs))

def get_test_pipeline(self, model, tokenizer, processor):
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
if processor is None:
self.skipTest(reason="No image processor")

@@ -175,7 +175,9 @@ def get_test_pipeline(self, model, tokenizer, processor):
"""
)

feature_extractor = ImageFeatureExtractionPipeline(model=model, image_processor=processor)
feature_extractor = ImageFeatureExtractionPipeline(
model=model, image_processor=processor, torch_dtype=torch_dtype
)
img = prepare_img()
return feature_extractor, [img, img]

4 changes: 2 additions & 2 deletions tests/pipelines/test_pipelines_image_segmentation.py
@@ -87,8 +87,8 @@ class ImageSegmentationPipelineTests(unittest.TestCase):
+ (MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING.items() if MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING else [])
)

def get_test_pipeline(self, model, tokenizer, processor):
image_segmenter = ImageSegmentationPipeline(model=model, image_processor=processor)
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
image_segmenter = ImageSegmentationPipeline(model=model, image_processor=processor, torch_dtype=torch_dtype)
return image_segmenter, [
"./tests/fixtures/tests_samples/COCO/000000039769.png",
"./tests/fixtures/tests_samples/COCO/000000039769.png",
10 changes: 8 additions & 2 deletions tests/pipelines/test_pipelines_image_to_image.py
@@ -54,9 +54,9 @@ class ImageToImagePipelineTests(unittest.TestCase):
@require_torch
@require_vision
@slow
def test_pipeline(self):
def test_pipeline(self, torch_dtype="float32"):
model_id = "caidas/swin2SR-classical-sr-x2-64"
upscaler = pipeline("image-to-image", model=model_id)
upscaler = pipeline("image-to-image", model=model_id, torch_dtype=torch_dtype)
upscaled_list = upscaler(self.examples)

self.assertEqual(len(upscaled_list), len(self.examples))
@@ -66,6 +66,12 @@ def test_pipeline(self):
self.assertEqual(upscaled_list[0].size, (1296, 976))
self.assertEqual(upscaled_list[1].size, (1296, 976))

@require_torch
@require_vision
@slow
def test_pipeline_fp16(self):
self.test_pipeline(torch_dtype="float16")

@require_torch
@require_vision
@slow
6 changes: 4 additions & 2 deletions tests/pipelines/test_pipelines_image_to_text.py
@@ -45,8 +45,10 @@ class ImageToTextPipelineTests(unittest.TestCase):
model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING
tf_model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING

def get_test_pipeline(self, model, tokenizer, processor):
pipe = pipeline("image-to-text", model=model, tokenizer=tokenizer, image_processor=processor)
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
pipe = pipeline(
"image-to-text", model=model, tokenizer=tokenizer, image_processor=processor, torch_dtype=torch_dtype
)
examples = [
Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
"./tests/fixtures/tests_samples/COCO/000000039769.png",
4 changes: 2 additions & 2 deletions tests/pipelines/test_pipelines_mask_generation.py
@@ -67,8 +67,8 @@ class MaskGenerationPipelineTests(unittest.TestCase):
(list(TF_MODEL_FOR_MASK_GENERATION_MAPPING.items()) if TF_MODEL_FOR_MASK_GENERATION_MAPPING else [])
)

def get_test_pipeline(self, model, tokenizer, processor):
image_segmenter = MaskGenerationPipeline(model=model, image_processor=processor)
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
image_segmenter = MaskGenerationPipeline(model=model, image_processor=processor, torch_dtype=torch_dtype)
return image_segmenter, [
"./tests/fixtures/tests_samples/COCO/000000039769.png",
"./tests/fixtures/tests_samples/COCO/000000039769.png",