Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix output data type of image classification #31444

Merged
merged 9 commits into from
Jun 25, 2024

Conversation

jiqing-feng
Copy link
Contributor

Hi @Narsil . I found an issue when running the low-precision pipeline of image-classification:

import torch
import requests
from transformers import pipeline
import PIL.Image

IMG_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"
model_id = "google/vit-base-patch16-224"
classifier = pipeline("image-classification", model=model_id, torch_dtype=torch.bfloat16)
image = PIL.Image.open(requests.get(IMG_URL, stream=True, timeout=3000).raw)

output = classifier(image)

Error log:

Traceback (most recent call last):
  File "test_image_pipeline.py", line 11, in <module>
    output = classifier(image)
  File "/home/jiqingfe/miniconda3/envs/ipex/lib/python3.8/site-packages/transformers/pipelines/image_classification.py", line 157, in __call__
    return super().__call__(images, **kwargs)
  File "/home/jiqingfe/miniconda3/envs/ipex/lib/python3.8/site-packages/transformers/pipelines/base.py", line 1243, in __call__
    return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
  File "/home/jiqingfe/miniconda3/envs/ipex/lib/python3.8/site-packages/transformers/pipelines/base.py", line 1251, in run_single
    outputs = self.postprocess(model_outputs, **postprocess_params)
  File "/home/jiqingfe/miniconda3/envs/ipex/lib/python3.8/site-packages/transformers/pipelines/image_classification.py", line 183, in postprocess
    outputs = outputs.numpy()
TypeError: Got unsupported ScalarType BFloat16

We need to convert the output tensor to float32 so it can be converted to numpy.

@jiqing-feng
Copy link
Contributor Author

Also, for the failed CI, it seems like not all tensors can be converted to fp32?

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on a fix for this!

We'll need to protect the float cast and add tests to check the pipeline works when torch_dtype is set to float16 or bfloat16

@jiqing-feng
Copy link
Contributor Author

jiqing-feng commented Jun 18, 2024

Hi @amyeroberts . Thanks for your review. I have fixed your comments and added a test for the low-precision pipeline, but the CI seems to be going wrong. Can you help check the test? Thx!

And other task pipelines also have this issue. I would like to hear your opinion and then change the rest of pipelines :)

@jiqing-feng
Copy link
Contributor Author

Hi @amyeroberts. Can you take a look at the new changes? If it is okay for you, I will update all other pipelines to support low precision.

@jiqing-feng
Copy link
Contributor Author

Hi @amyeroberts . Could you please take a review? The failed CIs are not related to my changes :)

I am waiting for your comments to change the rest of the pipelines to support it. Thx!

@jiqing-feng
Copy link
Contributor Author

Hi @amyeroberts . I have rebased and passed all CIs. Would you please review them?

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding!

Just a small request to add a test for bfloat16 too

@jiqing-feng
Copy link
Contributor Author

Hi @amyeroberts . I have added fp16 and bf16 tests for the image classification pipeline; could you please review them? Thx!

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing!

Just some small nits and we're ready to merge!

jiqing-feng and others added 3 commits June 25, 2024 04:51
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
@jiqing-feng
Copy link
Contributor Author

HI @amyeroberts . I have fixed the import issue, please take a review, thx!

@amyeroberts amyeroberts merged commit a958c4a into huggingface:main Jun 25, 2024
18 checks passed
@jiqing-feng jiqing-feng deleted the image-classification branch June 26, 2024 03:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants