Refactoring of ImageProcessorFast #35069

Merged

Commits (59)
2f00f0c
add init and base image processing functions
yonigozlan Dec 3, 2024
cfadb72
add add_fast_image_processor to transformers-cli
yonigozlan Dec 3, 2024
2cd73cb
add working fast image processor clip
yonigozlan Dec 3, 2024
932bd68
add fast image processor to doc, working tests
yonigozlan Dec 4, 2024
23d79ce
remove "to be implemented" SigLip
yonigozlan Dec 4, 2024
3f2d8a6
fix unprotected import
yonigozlan Dec 4, 2024
6a9d332
fix unprotected vision import
yonigozlan Dec 4, 2024
a1e2663
update ViTImageProcessorFast
yonigozlan Dec 4, 2024
fa74e7e
increase threshold slow fast ewuivalence
yonigozlan Dec 4, 2024
9dbd765
add fast img blip
yonigozlan Dec 4, 2024
d39ff52
add fast class in tests with cli
yonigozlan Dec 4, 2024
f609730
improve cli
yonigozlan Dec 5, 2024
8f7774d
add fast image processor convnext
yonigozlan Dec 6, 2024
809e1f0
add LlavaPatchingMixin and fast image processor for llava_next and ll…
yonigozlan Dec 7, 2024
f6e6cc2
add device kwarg to ImagesKwargs for fast processing on cuda
yonigozlan Dec 9, 2024
e1ce148
cleanup
yonigozlan Dec 9, 2024
a24d89c
fix unprotected import
yonigozlan Dec 9, 2024
522e200
group images by sizes and add batch processing
yonigozlan Dec 11, 2024
deefc5a
Add batch equivalence tests, skip when center_crop is used
yonigozlan Dec 11, 2024
6a2478e
cleanup
yonigozlan Dec 11, 2024
7d76305
update init and cli
yonigozlan Dec 11, 2024
142ed25
fix-copies
yonigozlan Dec 11, 2024
75bf56f
refactor convnext, cleanup base
yonigozlan Dec 16, 2024
de1fa18
fix
yonigozlan Dec 16, 2024
2ffc41d
remove patching mixins, add piped torchvision transforms for ViT
yonigozlan Dec 17, 2024
b524406
fix unbatched processing
yonigozlan Dec 17, 2024
9c2e2a4
fix f strings
yonigozlan Dec 17, 2024
8c773e0
protect imports
yonigozlan Dec 17, 2024
90fceba
change llava onevision to class transforms (test)
yonigozlan Dec 18, 2024
e878bdd
fix convnext
yonigozlan Dec 18, 2024
57acb7e
improve formatting (following Pavel review)
yonigozlan Jan 6, 2025
2a25104
fix handling device arg
yonigozlan Jan 6, 2025
4784fc8
improve cli
yonigozlan Jan 6, 2025
3ccd291
fix
yonigozlan Jan 6, 2025
053cdcb
fix inits
yonigozlan Jan 16, 2025
1b45e6e
Merge remote-tracking branch 'upstream/main' into improve-fast-image-…
yonigozlan Jan 21, 2025
9246945
Add distinction between preprocess and _preprocess, and support for a…
yonigozlan Jan 21, 2025
6ccd230
uniformize qwen2_vl fast
yonigozlan Jan 22, 2025
c4b8389
fix docstrings
yonigozlan Jan 22, 2025
e5c1e01
add add fast image processor llava
yonigozlan Jan 22, 2025
aef2fb4
remove min_pixels max_pixels from accepted size
yonigozlan Jan 22, 2025
7078a14
nit
yonigozlan Jan 22, 2025
aa94873
nit
yonigozlan Jan 22, 2025
13a125b
refactor fast image processors docstrings
yonigozlan Jan 28, 2025
8adb893
Merge remote-tracking branch 'upstream/main' into improve-fast-image-…
yonigozlan Jan 28, 2025
67d65f2
cleanup and remove fast class transforms
yonigozlan Jan 28, 2025
d225448
update add fast image processor transformers cli
yonigozlan Jan 28, 2025
80c6824
cleanup docstring
yonigozlan Jan 28, 2025
b96adfa
Merge remote-tracking branch 'upstream/main' into improve-fast-image-…
yonigozlan Jan 30, 2025
dbaacd1
uniformize pixtral fast and make _process_image explicit
yonigozlan Jan 30, 2025
b660e9d
Merge remote-tracking branch 'upstream/main' into improve-fast-image-…
yonigozlan Jan 30, 2025
b43ede1
fix prepare image structure llava next/onevision
yonigozlan Jan 30, 2025
3b05cbd
Use typed kwargs instead of explicit args
yonigozlan Feb 4, 2025
95db4a9
nit fix import Unpack
yonigozlan Feb 4, 2025
d9e1fcd
Merge branch 'main' into improve-fast-image-processor-base
yonigozlan Feb 4, 2025
6bd7a1b
clearly separate pops and gets in base preprocess. Use explicit typed…
yonigozlan Feb 4, 2025
565e482
Merge branch 'main' into improve-fast-image-processor-base
yonigozlan Feb 4, 2025
f85c06f
make qwen2_vl preprocess arguments hashable
yonigozlan Feb 4, 2025
1a7b0c4
Merge branch 'improve-fast-image-processor-base' of https://github.co…
yonigozlan Feb 4, 2025
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/llava.md
@@ -195,6 +195,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] LlavaImageProcessor
- preprocess

## LlavaImageProcessorFast

[[autodoc]] LlavaImageProcessorFast
- preprocess

## LlavaProcessor

[[autodoc]] LlavaProcessor
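For context, a minimal usage sketch of the fast processor documented above. The checkpoint name is an assumption for illustration; `use_fast=True` and the `device` kwarg follow the behavior described in this PR's commits.

```python
from transformers import AutoImageProcessor
from PIL import Image
import requests

# use_fast=True selects the torchvision-based *ImageProcessorFast class
# when one exists for the checkpoint (LlavaImageProcessorFast here).
processor = AutoImageProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf", use_fast=True)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# The device kwarg added in this PR runs the transforms on GPU; drop it
# (or pass "cpu") if no CUDA device is available.
inputs = processor(images=image, device="cuda", return_tensors="pt")
print(inputs["pixel_values"].shape)
```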
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -1302,6 +1302,7 @@
_import_structure["models.deformable_detr"].append("DeformableDetrImageProcessorFast")
_import_structure["models.deit"].append("DeiTImageProcessorFast")
_import_structure["models.detr"].append("DetrImageProcessorFast")
_import_structure["models.llava"].append("LlavaImageProcessorFast")
_import_structure["models.llava_next"].append("LlavaNextImageProcessorFast")
_import_structure["models.llava_onevision"].append("LlavaOnevisionImageProcessorFast")
_import_structure["models.pixtral"].append("PixtralImageProcessorFast")
@@ -6408,6 +6409,7 @@
from .models.deformable_detr import DeformableDetrImageProcessorFast
from .models.deit import DeiTImageProcessorFast
from .models.detr import DetrImageProcessorFast
from .models.llava import LlavaImageProcessorFast
from .models.llava_next import LlavaNextImageProcessorFast
from .models.llava_onevision import LlavaOnevisionImageProcessorFast
from .models.pixtral import PixtralImageProcessorFast
3 changes: 2 additions & 1 deletion src/transformers/commands/add_fast_image_processor.py
@@ -119,7 +119,7 @@ def add_import_statement_init(content: str, fast_image_processor_name: str, mode
"""
# Step 1: Find the block
block_regex = re.compile(
r"if not is_torchvision_available\(\):\s+raise OptionalDependencyNotAvailable\(\)\s+except OptionalDependencyNotAvailable:\s+from \.utils\.dummy_torchvision_objects import \*\s+else:(?P<else_block>\s*(\n\s*from .+ import .*\n)+)(?=\s*# Modeling)",
r"if not is_torchvision_available\(\):\s+raise OptionalDependencyNotAvailable\(\)\s+except OptionalDependencyNotAvailable:\s+from \.utils\.dummy_torchvision_objects import \*\s+else:(?P<else_block>\s*(\n\s*from .+ import .*\n)+)(?=\s*try:\s+if not \(is_torchvision_available\(\) and is_timm_available\(\)\):)",
re.DOTALL,
)
match = block_regex.search(content)
@@ -560,6 +560,7 @@ def add_fast_image_processor_file(
" # if the image processor is only used for simple augmentations, such as resizing, center cropping, rescaling, or normalizing,\n"
" # only the default values should be set in the class.\n"
" # If the image processor requires more complex augmentations, methods from BaseImageProcessorFast can be overridden.\n"
" # In most cases, only the `_preprocess` method should be overridden.\n\n"
" # For an example of a fast image processor requiring more complex augmentations, see `LlavaOnevisionImageProcessorFast`.\n\n"
" # Default values should be checked against the slow image processor\n"
" # None values left after checking can be removed\n"
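The generated template above describes the intended extension pattern: when a model only needs simple augmentations, set class-level defaults and override nothing. A minimal sketch under that assumption (the model name and default values are made up):

```python
from transformers.image_processing_utils_fast import BaseImageProcessorFast
from transformers.image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    PILImageResampling,
)


class MyModelImageProcessorFast(BaseImageProcessorFast):
    # Only simple augmentations (resize, rescale, normalize), so class-level
    # defaults are enough and no method needs to be overridden.
    resample = PILImageResampling.BICUBIC
    image_mean = IMAGENET_STANDARD_MEAN
    image_std = IMAGENET_STANDARD_STD
    size = {"height": 224, "width": 224}
    do_resize = True
    do_rescale = True
    do_normalize = True
```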
1 change: 0 additions & 1 deletion src/transformers/commands/transformers_cli.py
@@ -39,7 +39,6 @@ def main():
UserCommands.register_subcommand(commands_parser)
AddNewModelLikeCommand.register_subcommand(commands_parser)
LfsCommands.register_subcommand(commands_parser)
PTtoTFCommand.register_subcommand(commands_parser)
AddFastImageProcessorCommand.register_subcommand(commands_parser)

# Let's go
111 changes: 88 additions & 23 deletions src/transformers/image_processing_utils_fast.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
@@ -48,6 +49,7 @@
is_torchvision_available,
is_torchvision_v2_available,
is_vision_available,
logging,
)


@@ -65,6 +67,8 @@
else:
from torchvision.transforms import functional as F

logger = logging.get_logger(__name__)


def safe_squeeze(tensor: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor":
"""
@@ -131,6 +135,8 @@ class BaseImageProcessorFast(BaseImageProcessor):
size (`dict`, *optional*):
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
method.
default_to_square (`bool`, *optional*):
Whether to default to a square image when resizing, if size is an int.
resample (`PILImageResampling`, *optional*):
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
overridden by the `resample` parameter in the `preprocess` method.
@@ -173,12 +179,13 @@ class BaseImageProcessorFast(BaseImageProcessor):
do_normalize = None
do_convert_rgb = None
model_input_names = ["pixel_values"]
valid_extra_kwargs = ["default_to_square"]
valid_extra_kwargs = []

def __init__(
self,
do_resize: bool = None,
size: Dict[str, int] = None,
default_to_square: bool = None,
resample: Union["PILImageResampling", "F.InterpolationMode"] = None,
do_center_crop: bool = None,
crop_size: Dict[str, int] = None,
@@ -191,8 +198,12 @@ def __init__(
**kwargs,
) -> None:
size = size if size is not None else self.size
default_to_square = kwargs.pop(
"default_to_square", self.default_to_square if self.default_to_square is not None else True
default_to_square = (
default_to_square
if default_to_square is not None
else self.default_to_square
if self.default_to_square is not None
else True
)
size = get_size_dict(size, default_to_square=default_to_square) if size is not None else None
crop_size = crop_size if crop_size is not None else self.crop_size
@@ -210,6 +221,13 @@
self.image_mean = image_mean if image_mean is not None else self.image_mean
self.image_std = image_std if image_std is not None else self.image_std
self.do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
for key in self.valid_extra_kwargs:
if kwargs.get(key) is not None:
setattr(self, key, kwargs.pop(key))
else:
setattr(self, key, getattr(self, key, None))
if kwargs:
logger.warning_once(f"Found kwargs that are not in valid_extra_kwargs: {kwargs.keys()}")

def resize(
self,
@@ -393,27 +411,35 @@ def _prepare_input_images(
if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]:
raise ValueError(f"Unsupported input image type {image_type}")

if do_convert_rgb:
images = [self.convert_to_rgb(image) for image in images]
def process_image(image):
if do_convert_rgb:
image = self.convert_to_rgb(image)

if image_type == ImageType.PIL:
images = [F.pil_to_tensor(image) for image in images]
elif image_type == ImageType.NUMPY:
# not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
images = [torch.from_numpy(image).contiguous() for image in images]
if image_type == ImageType.PIL:
image = F.pil_to_tensor(image)
elif image_type == ImageType.NUMPY:
# not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
image = torch.from_numpy(image).contiguous()

# Now that we have torch tensors, we can move them to the right device
if device is not None:
images = [image.to(device) for image in images]
# Infer the channel dimension format if not provided
nonlocal input_data_format
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)

# We assume that all images have the same channel dimension format.
if input_data_format is None:
input_data_format = infer_channel_dimension_format(images[0])
if input_data_format == ChannelDimension.LAST:
# We force the channel dimension to be first for torch tensors as this is what torchvision expects.
images = [image.permute(2, 0, 1).contiguous() for image in images]
if input_data_format == ChannelDimension.LAST:
# We force the channel dimension to be first for torch tensors as this is what torchvision expects.
image = image.permute(2, 0, 1).contiguous()

return images
# Now that we have torch tensors, we can move them to the right device
if device is not None:
image = image.to(device)

return image

with ThreadPoolExecutor() as executor:
processed_images = list(executor.map(process_image, images))

return processed_images

def _prepare_process_arguments(
self,
@@ -515,8 +541,7 @@ def preprocess(
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Default to `"pt"` for PyTorch tensors if unset.
Fast image processors only support PyTorch tensors.
Returns stacked tensors if set to `"pt"`, otherwise returns a list of tensors.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
@@ -530,6 +555,8 @@
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
device (`torch.device`, *optional*):
The device to process the images on. If unset, the device is inferred from the input images.
kwargs:
Model-specific arguments to pass to the `_preprocess` method.
"""
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_extra_kwargs)

@@ -549,7 +576,10 @@
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

kwargs_dict = {
kwarg: kwargs.pop(kwarg) if kwargs.get(kwarg) is not None else getattr(self, kwarg, None)
for kwarg in self.valid_extra_kwargs
}
images = self._prepare_input_images(
A collaborator commented on this hunk:
Here you will have an issue: you are popping from the kwargs too many times, so you pop and then you use self, and the values will be different.
This should be simplified. Just update self with kwargs, then do the rest always with self, something like that.
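A minimal sketch of the simplification the reviewer describes, assuming the point is to make `self` the single source of truth (illustrative only, not the merged code):

```python
class FastProcessorSketch:
    """Illustrative sketch of the reviewer's suggestion, not the merged code."""

    valid_extra_kwargs = ["default_to_square"]
    default_to_square = True

    def preprocess(self, images, **kwargs):
        # Fold runtime kwargs into the instance once...
        for key in self.valid_extra_kwargs:
            if kwargs.get(key) is not None:
                setattr(self, key, kwargs.pop(key))
        # ...then read options only from `self`, so popped values and
        # attribute values can never diverge.
        print(f"default_to_square={self.default_to_square}")
        return images


FastProcessorSketch().preprocess([], default_to_square=False)  # default_to_square=False
```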

images=images,
do_convert_rgb=do_convert_rgb,
@@ -572,6 +602,41 @@
device=images[0].device,
)

return self._preprocess(
images=images,
do_resize=do_resize,
size=size,
interpolation=interpolation,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
return_tensors=return_tensors,
**kwargs_dict,
)

def _preprocess(
self,
images: List["torch.Tensor"],
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
return_tensors: Optional[Union[str, TensorType]],
) -> BatchFeature:
"""
Preprocess images.
"""

# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
resized_images_grouped = {}
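The `_preprocess` hunk is truncated here, but the idea it opens with is worth spelling out: images of identical size are stacked so each unique input size costs one batched torchvision call instead of one call per image. An illustrative re-implementation, not the verbatim helpers from the PR:

```python
from collections import defaultdict
from typing import Dict, List, Tuple

import torch
from torchvision.transforms import functional as F


def group_images_by_shape(images: List[torch.Tensor]):
    """Group (C, H, W) tensors by spatial size so each group can be stacked."""
    grouped: Dict[Tuple[int, int], List[torch.Tensor]] = defaultdict(list)
    order: List[Tuple[Tuple[int, int], int]] = []  # (shape, index within group)
    for image in images:
        shape = tuple(image.shape[-2:])
        order.append((shape, len(grouped[shape])))
        grouped[shape].append(image)
    return {shape: torch.stack(group) for shape, group in grouped.items()}, order


def batched_resize(images: List[torch.Tensor], size=(224, 224)) -> List[torch.Tensor]:
    grouped, order = group_images_by_shape(images)
    # One resize call per unique input size instead of one per image.
    resized = {shape: F.resize(batch, list(size)) for shape, batch in grouped.items()}
    # Restore the original ordering of the input list.
    return [resized[shape][i] for shape, i in order]
```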