Skip to content

Commit

Permalink
add Qwen2-VL image processor fast (huggingface#35733)
Browse files Browse the repository at this point in the history
* add qwen2_vl image processor fast

* add device to ImagesKwargs

* remove automatic fix copies

* fix fast_is_faster_than_slow

* remove unnecessary import
  • Loading branch information
yonigozlan authored and bursteratom committed Jan 28, 2025
1 parent 3f5e26d commit 1f266fe
Show file tree
Hide file tree
Showing 9 changed files with 596 additions and 139 deletions.
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/qwen2_vl.md
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,11 @@ model = Qwen2VLForConditionalGeneration.from_pretrained(
[[autodoc]] Qwen2VLImageProcessor
- preprocess

## Qwen2VLImageProcessorFast

[[autodoc]] Qwen2VLImageProcessorFast
- preprocess

## Qwen2VLProcessor

[[autodoc]] Qwen2VLProcessor
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1299,6 +1299,7 @@
_import_structure["models.deformable_detr"].append("DeformableDetrImageProcessorFast")
_import_structure["models.detr"].append("DetrImageProcessorFast")
_import_structure["models.pixtral"].append("PixtralImageProcessorFast")
_import_structure["models.qwen2_vl"].append("Qwen2VLImageProcessorFast")
_import_structure["models.rt_detr"].append("RTDetrImageProcessorFast")
_import_structure["models.vit"].append("ViTImageProcessorFast")

Expand Down Expand Up @@ -6397,6 +6398,7 @@
from .models.deformable_detr import DeformableDetrImageProcessorFast
from .models.detr import DetrImageProcessorFast
from .models.pixtral import PixtralImageProcessorFast
from .models.qwen2_vl import Qwen2VLImageProcessorFast
from .models.rt_detr import RTDetrImageProcessorFast
from .models.vit import ViTImageProcessorFast

Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/auto/image_processing_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@
("poolformer", ("PoolFormerImageProcessor",)),
("pvt", ("PvtImageProcessor",)),
("pvt_v2", ("PvtImageProcessor",)),
("qwen2_vl", ("Qwen2VLImageProcessor",)),
("qwen2_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
("regnet", ("ConvNextImageProcessor",)),
("resnet", ("ConvNextImageProcessor",)),
("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")),
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/qwen2_vl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
if TYPE_CHECKING:
from .configuration_qwen2_vl import *
from .image_processing_qwen2_vl import *
from .image_processing_qwen2_vl_fast import *
from .modeling_qwen2_vl import *
from .processing_qwen2_vl import *
else:
Expand Down
422 changes: 422 additions & 0 deletions src/transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions src/transformers/processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ class methods and docstrings.
The channel dimension format for the output image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image.
device (`str`, *optional*):
The device to use for processing (e.g. "cpu", "cuda"), only relevant for fast image processing.
"""

do_resize: Optional[bool]
Expand All @@ -188,6 +190,7 @@ class methods and docstrings.
do_center_crop: Optional[bool]
data_format: Optional[ChannelDimension]
input_data_format: Optional[Union[str, ChannelDimension]]
device: Optional[str]


class VideosKwargs(TypedDict, total=False):
Expand Down
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_torchvision_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torchvision"])


class Qwen2VLImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torchvision"])


class RTDetrImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]

Expand Down
286 changes: 149 additions & 137 deletions tests/models/qwen2_vl/test_image_processing_qwen2_vl.py

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion tests/test_image_processing_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,10 @@ def test_slow_fast_equivalence(self):
encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")

self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-2))
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
)

@require_vision
@require_torch
Expand All @@ -193,6 +196,8 @@ def test_fast_is_faster_than_slow(self):
self.skipTest(reason="Skipping speed test as one of the image processors is not defined")

def measure_time(image_processor, image):
# Warmup
_ = image_processor(image, return_tensors="pt")
start = time.time()
_ = image_processor(image, return_tensors="pt")
return time.time() - start
Expand Down

0 comments on commit 1f266fe

Please sign in to comment.