From f94b391b68dedc4314544f79dfbe4284fbe9a0d9 Mon Sep 17 00:00:00 2001
From: Raushan Turganbay
Date: Thu, 30 Jan 2025 12:40:18 +0100
Subject: [PATCH] Pixtral: vectorize patch embeddings and enable tests (#35122)

* initial POC

* - batch mix feature

* fix tests

* fix tests

* make style

* do not skip and instead fix tests

* update

* return back the test

* correct text with the correct ckpt

---
 .../models/llava/modeling_llava.py            |   6 +-
 .../pixtral/image_processing_pixtral.py       | 224 +++++++-----------
 .../pixtral/image_processing_pixtral_fast.py  | 147 +++++++-----
 .../models/pixtral/modeling_pixtral.py        |  61 +++--
 .../models/pixtral/processing_pixtral.py      |  92 ++-----
 tests/models/llava/test_modeling_llava.py     |  77 ++++--
 .../pixtral/test_image_processing_pixtral.py  | 130 +++++-----
 tests/models/pixtral/test_modeling_pixtral.py | 124 ++--------
 .../models/pixtral/test_processor_pixtral.py  | 116 +++++----
 tests/test_modeling_common.py                 |   4 +
 10 files changed, 429 insertions(+), 552 deletions(-)

diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py
index 7b4055072282..67313c8f55d3 100644
--- a/src/transformers/models/llava/modeling_llava.py
+++ b/src/transformers/models/llava/modeling_llava.py
@@ -280,6 +280,7 @@ def get_image_features(
         pixel_values: torch.FloatTensor,
         vision_feature_layer: Union[int, List[int]],
         vision_feature_select_strategy: str,
+        **kwargs,
     ):
         """
         Obtains image last hidden states from the vision tower and apply multimodal projection.
@@ -300,8 +301,9 @@ def get_image_features(
         if vision_feature_select_strategy not in ["default", "full"]:
             raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
 
+        kwargs = {k: v for k, v in kwargs.items() if v is not None}
         # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
- image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) + image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs) # If we have one vision feature layer, return the corresponding hidden states, # otherwise, select the hidden states of each feature layer and concatenate them @@ -422,6 +424,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, + image_sizes: torch.Tensor = None, ) -> Union[Tuple, LlavaCausalLMOutputWithPast]: r""" Args: @@ -492,6 +495,7 @@ def forward( pixel_values=pixel_values, vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, + image_sizes=image_sizes, ) n_image_tokens = (input_ids == self.config.image_token_index).sum().item() diff --git a/src/transformers/models/pixtral/image_processing_pixtral.py b/src/transformers/models/pixtral/image_processing_pixtral.py index 6d83e0c46471..969575d2e49a 100644 --- a/src/transformers/models/pixtral/image_processing_pixtral.py +++ b/src/transformers/models/pixtral/image_processing_pixtral.py @@ -15,12 +15,13 @@ """Image processor class for Pixtral.""" import math -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import numpy as np from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from ...image_transforms import ( + pad, resize, to_channel_dimension_format, ) @@ -31,13 +32,13 @@ get_image_size, infer_channel_dimension_format, is_scaled_image, - is_valid_image, + make_list_of_images, to_numpy_array, valid_images, validate_kwargs, validate_preprocess_arguments, ) -from ...utils import TensorType, is_torch_device, is_torch_dtype, is_vision_available, logging +from ...utils import TensorType, is_vision_available, logging from ...utils.import_utils import requires_backends @@ -48,91 +49,6 @@ import PIL -class BatchMixFeature(BatchFeature): - def to(self, *args, **kwargs) -> "BatchMixFeature": - """ - Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in - different `dtypes` and sending the `BatchFeature` to a different `device`. - - Args: - args (`Tuple`): - Will be passed to the `to(...)` function of the tensors. - kwargs (`Dict`, *optional*): - Will be passed to the `to(...)` function of the tensors. - - Returns: - [`BatchFeature`]: The same instance after modification. 
- """ - - def _recursive_to(obj, device, *args, **kwargs): - # Lists can be nested, so keep digging until we hit tensors - if isinstance(obj, list): - return [_recursive_to(o, device, *args, **kwargs) for o in obj] - # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor` - elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj): - # cast and send to device - return obj.to(*args, **kwargs) - elif isinstance(obj, torch.Tensor) and device is not None: - # only send to device, don't cast - return obj.to(device=device) - else: - return obj - - requires_backends(self, ["torch"]) - import torch # noqa - - device = kwargs.get("device") - # Check if the args are a device or a dtype - if device is None and len(args) > 0: - # device should be always the first argument - arg = args[0] - if is_torch_dtype(arg): - # The first argument is a dtype - pass - elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int): - device = arg - else: - # it's something else - raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.") - - self.data = {k: _recursive_to(v, device, *args, **kwargs) for k, v in self.data.items()} - return self - - -# Copied from transformers.models.idefics2.image_processing_idefics2.make_list_of_images -def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]: - """ - Convert a single image or a list of images to a list of numpy arrays. - - Args: - images (`ImageInput`): - A single image or a list of images. - - Returns: - A list of numpy arrays. - """ - # If it's a single image, convert it to a list of lists - if is_valid_image(images): - images = [[images]] - # If it's a list of images, it's a single batch, so convert it to a list of lists - elif isinstance(images, (list, tuple)) and len(images) > 0 and is_valid_image(images[0]): - images = [images] - # If it's a list of batches, it's already in the right format - elif ( - isinstance(images, (list, tuple)) - and len(images) > 0 - and isinstance(images[0], (list, tuple)) - and len(images[0]) > 0 - and is_valid_image(images[0][0]) - ): - pass - else: - raise ValueError( - "Invalid input type. Must be a single image, a list of images, or a list of batches of images." - ) - return images - - # Adapted from function in image_transforms.py to ensure any transparent pixels are converted to white. def convert_to_rgb(image: ImageInput) -> ImageInput: """ @@ -219,18 +135,6 @@ def get_resize_output_image_size( return num_height_tokens * patch_height, num_width_tokens * patch_width -# Hack to get tensor conversion used in BatchFeature without batching the images -def _get_is_as_tensor_fns(tensor_type: Union[str, TensorType]) -> Tuple[Callable, Callable]: - return BatchFeature()._get_is_as_tensor_fns(tensor_type) - - -def convert_to_tensor(array, tensor_type: Union[str, TensorType]) -> Any: - is_tensor, as_tensor = _get_is_as_tensor_fns(tensor_type) - if is_tensor(array): - return array - return as_tensor(array) - - class PixtralImageProcessor(BaseImageProcessor): r""" Constructs a Pixtral image processor. @@ -368,6 +272,49 @@ def resize( **kwargs, ) + def _pad_for_batching( + self, + pixel_values: List[np.ndarray], + image_sizes: List[List[int]], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches. 
+ Args: + pixel_values (`List[np.ndarray]`): + An array of pixel values of each images of shape (`batch_size`, `height`, `width`, `channels`) + image_sizes (`List[List[int]]`): + A list of sizes for each image in `pixel_values` in (height, width) format. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use same as the input image. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use the inferred format of the input image. + Returns: + List[`np.ndarray`]: The padded images. + """ + + max_shape = ( + max([size[0] for size in image_sizes]), + max([size[1] for size in image_sizes]), + ) + pixel_values = [ + pad( + image, + padding=((0, max_shape[0] - size[0]), (0, max_shape[1] - size[1])), + data_format=data_format, + input_data_format=input_data_format, + ) + for image, size in zip(pixel_values, image_sizes) + ] + return pixel_values + def preprocess( self, images: ImageInput, @@ -449,9 +396,9 @@ def preprocess( validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) - images_list = make_list_of_images(images) + images = make_list_of_images(images) - if not valid_images(images_list[0]): + if not valid_images(images[0]): raise ValueError( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." @@ -469,12 +416,12 @@ def preprocess( ) if do_convert_rgb: - images_list = [[convert_to_rgb(image) for image in images] for images in images_list] + images = [convert_to_rgb(image) for image in images] # All transformations expect numpy arrays. - images_list = [[to_numpy_array(image) for image in images] for images in images_list] + images = [to_numpy_array(image) for image in images] - if do_rescale and is_scaled_image(images_list[0][0]): + if do_rescale and is_scaled_image(images[0]): logger.warning_once( "It looks like you are trying to rescale already rescaled images. If the input" " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." @@ -482,44 +429,43 @@ def preprocess( if input_data_format is None: # We assume that all images have the same channel dimension format. 
- input_data_format = infer_channel_dimension_format(images_list[0][0]) + input_data_format = infer_channel_dimension_format(images[0]) batch_images = [] batch_image_sizes = [] - for sample_images in images_list: - images = [] - image_sizes = [] - for image in sample_images: - if do_resize: - image = self.resize( - image=image, - size=size, - patch_size=patch_size, - resample=resample, - input_data_format=input_data_format, - ) - - if do_rescale: - image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) - - if do_normalize: - image = self.normalize( - image=image, mean=image_mean, std=image_std, input_data_format=input_data_format - ) - - images.append(image) - image_sizes.append(get_image_size(image, input_data_format)) - batch_images.append(images) - batch_image_sizes.append(image_sizes) - - images_list = [ - [to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images] - for images in batch_images - ] + for image in images: + if do_resize: + image = self.resize( + image=image, + size=size, + patch_size=patch_size, + resample=resample, + input_data_format=input_data_format, + ) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format + ) + + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + + batch_images.append(image) + batch_image_sizes.append(get_image_size(image, data_format)) + + pixel_values = self._pad_for_batching( + pixel_values=batch_images, + image_sizes=batch_image_sizes, + input_data_format=data_format, + data_format=data_format, + ) - # Convert to tensor type outside of BatchFeature to avoid batching the images of different sizes - images_list = [[convert_to_tensor(image, return_tensors) for image in images] for images in images_list] - return BatchMixFeature(data={"pixel_values": images_list, "image_sizes": batch_image_sizes}, tensor_type=None) + return BatchFeature( + data={"pixel_values": pixel_values, "image_sizes": batch_image_sizes}, tensor_type=return_tensors + ) __all__ = ["PixtralImageProcessor"] diff --git a/src/transformers/models/pixtral/image_processing_pixtral_fast.py b/src/transformers/models/pixtral/image_processing_pixtral_fast.py index 082e255c8435..1013c6917671 100644 --- a/src/transformers/models/pixtral/image_processing_pixtral_fast.py +++ b/src/transformers/models/pixtral/image_processing_pixtral_fast.py @@ -16,7 +16,7 @@ from typing import Dict, List, Optional, Union -from ...image_processing_utils import get_size_dict +from ...image_processing_utils import BatchFeature, get_size_dict from ...image_processing_utils_fast import BaseImageProcessorFast from ...image_utils import ( ChannelDimension, @@ -26,6 +26,7 @@ get_image_size, get_image_type, infer_channel_dimension_format, + make_list_of_images, validate_fast_preprocess_arguments, validate_kwargs, ) @@ -38,10 +39,8 @@ logging, ) from .image_processing_pixtral import ( - BatchMixFeature, convert_to_rgb, get_resize_output_image_size, - make_list_of_images, ) @@ -189,6 +188,36 @@ def resize( **kwargs, ) + # Adapted from transformers.models.pixtral.image_processing_pixtral.PixtralImageProcessor._pad_for_batching + def _pad_for_batching( + self, + pixel_values: List[torch.Tensor], + image_sizes: List[List[int]], + ): + """ + Pads images on the `num_of_patches` dimension with zeros to form a 
batch of same number of patches. + Args: + pixel_values (`List[torch.Tensor]`): + An array of pixel values of each images of shape (`batch_size`, `channels`, `height`, `width`) + image_sizes (`List[List[int]]`): + A list of sizes for each image in `pixel_values` in (height, width) format. + Returns: + List[`torch.Tensor`]: The padded images. + """ + + max_shape = ( + max([size[0] for size in image_sizes]), + max([size[1] for size in image_sizes]), + ) + pixel_values = [ + torch.nn.functional.pad( + image, + pad=(0, max_shape[1] - size[1], 0, max_shape[0] - size[0]), + ) + for image, size in zip(pixel_values, image_sizes) + ] + return torch.stack(pixel_values) + def preprocess( self, images: ImageInput, @@ -206,7 +235,7 @@ def preprocess( data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, **kwargs, - ) -> BatchMixFeature: + ) -> BatchFeature: """ Preprocess an image or batch of images. @@ -271,8 +300,8 @@ def preprocess( validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) - images_list = make_list_of_images(images) - image_type = get_image_type(images_list[0][0]) + images = make_list_of_images(images) + image_type = get_image_type(images[0]) if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]: raise ValueError(f"Unsupported input image type {image_type}") @@ -290,65 +319,63 @@ def preprocess( data_format=data_format, ) - if do_convert_rgb: - images_list = [[convert_to_rgb(image) for image in images] for images in images_list] - - if image_type == ImageType.PIL: - images_list = [[F.pil_to_tensor(image) for image in images] for images in images_list] - elif image_type == ImageType.NUMPY: - # not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays - images_list = [[torch.from_numpy(image).contiguous() for image in images] for images in images_list] - - if device is not None: - images_list = [[image.to(device) for image in images] for images in images_list] - - # We assume that all images have the same channel dimension format. 
- if input_data_format is None: - input_data_format = infer_channel_dimension_format(images_list[0][0]) - if input_data_format == ChannelDimension.LAST: - images_list = [[image.permute(2, 0, 1).contiguous() for image in images] for images in images_list] - input_data_format = ChannelDimension.FIRST - if do_rescale and do_normalize: # fused rescale and normalize - new_mean = torch.tensor(image_mean, device=images_list[0][0].device) * (1.0 / rescale_factor) - new_std = torch.tensor(image_std, device=images_list[0][0].device) * (1.0 / rescale_factor) + new_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor) + new_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor) batch_images = [] batch_image_sizes = [] - for sample_images in images_list: - images = [] - image_sizes = [] - for image in sample_images: - if do_resize: - interpolation = ( - pil_torch_interpolation_mapping[resample] - if isinstance(resample, (PILImageResampling, int)) - else resample - ) - image = self.resize( - image=image, - size=size, - patch_size=patch_size, - interpolation=interpolation, - ) - - if do_rescale and do_normalize: - # fused rescale and normalize - image = F.normalize(image.to(dtype=torch.float32), new_mean, new_std) - elif do_rescale: - image = image * rescale_factor - elif do_normalize: - image = F.normalize(image, image_mean, image_std) - - images.append(image) - image_sizes.append(get_image_size(image, input_data_format)) - batch_images.append(images) - batch_image_sizes.append(image_sizes) - - return BatchMixFeature( - data={"pixel_values": batch_images, "image_sizes": batch_image_sizes}, - tensor_type=None, + for image in images: + if do_convert_rgb: + image = convert_to_rgb(image) + + if image_type == ImageType.PIL: + image = F.pil_to_tensor(image) + elif image_type == ImageType.NUMPY: + # not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays + image = torch.from_numpy(image).contiguous() + + # We assume that all images have the same channel dimension format. + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + + if input_data_format == ChannelDimension.LAST: + image = image.permute(2, 0, 1).contiguous() + + image = image.to(device) + + if do_resize: + interpolation = ( + pil_torch_interpolation_mapping[resample] + if isinstance(resample, (PILImageResampling, int)) + else resample + ) + image = self.resize( + image=image, + size=size, + patch_size=patch_size, + interpolation=interpolation, + ) + + if do_rescale and do_normalize: + # fused rescale and normalize + image = F.normalize(image.to(dtype=torch.float32), new_mean, new_std) + elif do_rescale: + image = image * rescale_factor + elif do_normalize: + image = F.normalize(image, image_mean, image_std) + + batch_images.append(image) + batch_image_sizes.append(get_image_size(image, ChannelDimension.FIRST)) + + pixel_values = self._pad_for_batching( + pixel_values=batch_images, + image_sizes=batch_image_sizes, + ) + + return BatchFeature( + data={"pixel_values": pixel_values, "image_sizes": batch_image_sizes}, tensor_type=return_tensors ) diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py index 905eef22ca3d..af41bab84259 100644 --- a/src/transformers/models/pixtral/modeling_pixtral.py +++ b/src/transformers/models/pixtral/modeling_pixtral.py @@ -14,7 +14,7 @@ # limitations under the License. 
"""PyTorch Pixtral model.""" -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -57,7 +57,7 @@ class PixtralRotaryEmbedding(nn.Module): a corresponding positional embedding, based on its index in the grid. """ - def __init__(self, config, device): + def __init__(self, config, device=None): super().__init__() self.rope_type = "default" self.dim = config.head_dim @@ -89,7 +89,6 @@ def forward(self, x, position_ids): # Core RoPE block freqs = self.inv_freq[position_ids] - # position_ids_expanded = position_ids[:, None, :].float() # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" @@ -175,7 +174,7 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - position_embeddings: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" @@ -261,8 +260,8 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, - position_embeddings: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + output_attentions: Optional[bool] = None, ) -> Tuple[torch.FloatTensor]: """ Args: @@ -310,7 +309,7 @@ def forward( self, inputs_embeds, attention_mask: Optional[torch.Tensor] = None, - position_embeddings: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, @@ -375,7 +374,7 @@ def forward( if not return_dict: return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=[hidden_states], attentions=all_attentions + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) @@ -399,10 +398,9 @@ def forward( class PixtralPreTrainedModel(PreTrainedModel): config_class = PixtralVisionConfig base_model_prefix = "model" + main_input_name = "pixel_values" supports_gradient_checkpointing = True - _no_split_modules = ["PixtralVisionAttention"] - _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True + _no_split_modules = ["PixtralAttentionLayer"] def _init_weights(self, module): std = ( @@ -426,6 +424,8 @@ def _init_weights(self, module): pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`AutoImageProcessor.__call__`] for details. + image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*): + The sizes of the images in the batch, being (height, width) for each image. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. 
@@ -470,15 +470,22 @@ def __init__(self, config): stride=config.patch_size, bias=False, ) + self.patch_size = config.patch_size self.ln_pre = PixtralRMSNorm(config.hidden_size, eps=1e-5) self.transformer = PixtralTransformer(config) - self.patch_positional_embedding = PixtralRotaryEmbedding(config, device=self.device) + self.patch_positional_embedding = PixtralRotaryEmbedding(config) + + self.post_init() + + def get_input_embeddings(self): + return self.patch_conv @add_start_docstrings_to_model_forward(PIXTRAL_INPUTS_DOCSTRING) def forward( self, - pixel_values: List[torch.Tensor], - output_hidden_states: Optional[bool] = False, + pixel_values: torch.Tensor, + image_sizes: torch.Tensor, + output_hidden_states: Optional[bool] = None, output_attentions: Optional[bool] = None, return_dict: Optional[bool] = None, *args, @@ -490,24 +497,36 @@ def forward( all tokens of all images of shape (N_toks, D) """ # pass images through initial convolution independently - if len(pixel_values) > 1: - raise ValueError("Batching/padding not supported yet!") - patch_embeds_list = [self.patch_conv(img.to(self.dtype)) for sample in pixel_values for img in sample] + patch_embeds = self.patch_conv(pixel_values) + patch_embeds_list = [ + embed[..., : (size[0] // self.patch_size), : (size[1] // self.patch_size)] + for embed, size in zip(patch_embeds, image_sizes) + ] # flatten to a single sequence patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0).unsqueeze(0) patch_embeds = self.ln_pre(patch_embeds) + # positional embeddings position_ids = position_ids_in_meshgrid( patch_embeds_list, max_width=self.config.image_size // self.config.patch_size - ).to(self.device) - - position_embedding = self.patch_positional_embedding(patch_embeds, position_ids) + ) + position_embeddings = self.patch_positional_embedding(patch_embeds, position_ids) attention_mask = generate_block_attention_mask( [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds ) - return self.transformer(patch_embeds, attention_mask, position_embedding) + + out = self.transformer( + patch_embeds, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + ) + + return out __all__ = ["PixtralVisionModel", "PixtralPreTrainedModel"] diff --git a/src/transformers/models/pixtral/processing_pixtral.py b/src/transformers/models/pixtral/processing_pixtral.py index e60151130ae0..aea6375f78bc 100644 --- a/src/transformers/models/pixtral/processing_pixtral.py +++ b/src/transformers/models/pixtral/processing_pixtral.py @@ -22,7 +22,7 @@ from ...image_utils import ImageInput, is_valid_image, load_image from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order from ...tokenization_utils_base import PreTokenizedInput, TextInput -from ...utils import is_torch_device, is_torch_dtype, logging, requires_backends +from ...utils import logging logger = logging.get_logger(__name__) @@ -50,58 +50,6 @@ def is_image_or_image_url(elem): return is_url(elem) or is_valid_image(elem) -# Copied from transformers.models.pixtral.image_processing_pixtral.BatchMixFeature -class BatchMixFeature(BatchFeature): - def to(self, *args, **kwargs) -> "BatchMixFeature": - """ - Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in - different `dtypes` and sending the `BatchFeature` to a different `device`. 
- - Args: - args (`Tuple`): - Will be passed to the `to(...)` function of the tensors. - kwargs (`Dict`, *optional*): - Will be passed to the `to(...)` function of the tensors. - - Returns: - [`BatchFeature`]: The same instance after modification. - """ - - def _recursive_to(obj, device, *args, **kwargs): - # Lists can be nested, so keep digging until we hit tensors - if isinstance(obj, list): - return [_recursive_to(o, device, *args, **kwargs) for o in obj] - # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor` - elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj): - # cast and send to device - return obj.to(*args, **kwargs) - elif isinstance(obj, torch.Tensor) and device is not None: - # only send to device, don't cast - return obj.to(device=device) - else: - return obj - - requires_backends(self, ["torch"]) - import torch # noqa - - device = kwargs.get("device") - # Check if the args are a device or a dtype - if device is None and len(args) > 0: - # device should be always the first argument - arg = args[0] - if is_torch_dtype(arg): - # The first argument is a dtype - pass - elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int): - device = arg - else: - # it's something else - raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.") - - self.data = {k: _recursive_to(v, device, *args, **kwargs) for k, v in self.data.items()} - return self - - class PixtralProcessor(ProcessorMixin): r""" Constructs a Pixtral processor which wraps a Pixtral image processor and a Pixtral tokenizer into a single processor. @@ -161,7 +109,7 @@ def __call__( audio=None, videos=None, **kwargs: Unpack[PixtralProcessorKwargs], - ) -> BatchMixFeature: + ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode @@ -205,28 +153,16 @@ def __call__( if images is not None: if is_image_or_image_url(images): - if isinstance(text, str) or isinstance(text, list) and len(text) == 1: - # If there's a single sample, the image must belong to it - images = [[images]] - else: - raise ValueError( - "You have supplied multiple text samples, but `images` is not a nested list. When processing multiple samples, `images` should be a list of lists of images, one list per sample." - ) + images = [images] elif isinstance(images, list) and is_image_or_image_url(images[0]): - if isinstance(text, str) or isinstance(text, list) and len(text) == 1: - # If there's a single sample, all images must belong to it - images = [images] - else: - raise ValueError( - "You have supplied multiple text samples, but `images` is not a nested list. When processing multiple samples, `images` should be a list of lists of images, one list per sample." - ) - elif isinstance(images, list) and isinstance(images[0], list) and is_image_or_image_url(images[0][0]): pass + elif isinstance(images, list) and isinstance(images[0], list) and is_image_or_image_url(images[0][0]): + images = [image for sublist in images for image in sublist] else: raise ValueError( "Invalid input images. Please provide a single image, a list of images, or a list of lists of images." 
) - images = [[load_image(im) for im in sample] for sample in images] + images = [load_image(im) if isinstance(im, str) else im for im in images] image_inputs = self.image_processor(images, patch_size=self.patch_size, **output_kwargs["images_kwargs"]) else: image_inputs = {} @@ -240,15 +176,13 @@ def __call__( prompt_strings = text if image_inputs.get("pixel_values") is not None: # Replace the image token with the expanded image token sequence - images = image_inputs["pixel_values"] - image_sizes = image_inputs.pop("image_sizes") + image_sizes = iter(image_inputs["image_sizes"]) prompt_strings = [] + replace_strings = [] - for sample_images, sample_image_sizes, sample in zip(images, image_sizes, text): - replace_strings = [] - # First calculate the number of tokens needed for each image and put in a placeholder - for image, image_size in zip(sample_images, sample_image_sizes): - height, width = image_size + for sample in text: + while self.image_token in sample: + height, width = next(image_sizes) num_height_tokens = height // self.patch_size num_width_tokens = width // self.patch_size replace_tokens = [ @@ -267,7 +201,9 @@ def __call__( prompt_strings.append(sample) text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) - return BatchMixFeature(data={**text_inputs, **image_inputs}) + return BatchFeature( + data={**text_inputs, **image_inputs}, tensor_type=output_kwargs["common_kwargs"]["return_tensors"] + ) # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama def batch_decode(self, *args, **kwargs): diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py index 23663ee649a7..5e19a74f221a 100644 --- a/tests/models/llava/test_modeling_llava.py +++ b/tests/models/llava/test_modeling_llava.py @@ -564,9 +564,8 @@ def test_generation_siglip_backbone(self): self.assertTrue(processor.batch_decode(output, skip_special_tokens=True)[0] == EXPECTED_DECODED_TEXT) @slow - @require_bitsandbytes def test_pixtral(self): - model_id = "hf-internal-testing/pixtral-12b" + model_id = "mistral-community/pixtral-12b" model = LlavaForConditionalGeneration.from_pretrained(model_id) processor = AutoProcessor.from_pretrained(model_id) @@ -579,33 +578,75 @@ def test_pixtral(self): PROMPT = "[INST]Describe the images.\n[IMG][IMG][IMG][IMG][/INST]" # image = Image.open(requests.get(url, stream=True).raw) - inputs = processor(text=PROMPT, images=IMG_URLS, return_tensors="pt").to("cuda") + inputs = processor(text=PROMPT, images=IMG_URLS, return_tensors="pt").to(model.device) generate_ids = model.generate(**inputs, max_new_tokens=500) ouptut = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + print(ouptut) # fmt: off EXPECTED_GENERATION = """ Describe the images. -Sure, let's break down each image description: +Certainly! Here are the descriptions of the images: -1. **Image 1:** - - **Description:** A black dog with a glossy coat is sitting on a wooden floor. The dog has a focused expression and is looking directly at the camera. - - **Details:** The wooden floor has a rustic appearance with visible wood grain patterns. The dog's eyes are a striking color, possibly brown or amber, which contrasts with its black fur. +1. **Image 1**: This image features a black dog with a glossy coat sitting on a wooden surface. The dog has a calm and attentive expression, looking directly at the camera. 
The wooden background has a rustic appearance with visible grain and texture. -2. **Image 2:** - - **Description:** A scenic view of a mountainous landscape with a winding road cutting through it. The road is surrounded by lush green vegetation and leads to a distant valley. - - **Details:** The mountains are rugged with steep slopes, and the sky is clear, indicating good weather. The winding road adds a sense of depth and perspective to the image. +2. **Image 2**: This image captures a breathtaking view of a mountainous landscape. The mountains are rugged and covered with patches of green vegetation. The sky above is clear, and the scene conveys a sense of tranquility and natural beauty. -3. **Image 3:** - - **Description:** A beach scene with waves crashing against the shore. There are several people in the water and on the beach, enjoying the waves and the sunset. - - **Details:** The waves are powerful, creating a dynamic and lively atmosphere. The sky is painted with hues of orange and pink from the setting sun, adding a warm glow to the scene. +3. **Image 3**: This image shows a beach scene during sunset. The waves are gently rolling onto the shore, and several people can be seen in the water, possibly surfing or swimming. The sky is painted with warm hues of orange and yellow, creating a serene and picturesque atmosphere. -4. **Image 4:** - - **Description:** A garden path leading to a large tree with a bench underneath it. The path is bordered by well-maintained grass and flowers. - - **Details:** The path is made of small stones or gravel, and the tree provides a shaded area with the bench invitingly placed beneath it. The surrounding area is lush and green, suggesting a well-kept garden. +4. **Image 4**: This image depicts a narrow, winding path that cuts through a lush, green landscape. On either side of the path, there is dense grass and various trees, including a prominent tree with white blossoms. The sky is clear and blue, adding to the peaceful and inviting ambiance of the scene. -Each image captures a different scene, from a close-up of a dog to expansive natural landscapes, showcasing various elements of nature and human interaction with it. +These descriptions provide a detailed overview of the content and atmosphere of each image. """ # fmt: on # check that both inputs are handled correctly and generate the same output - self.assertListEqual(ouptut, EXPECTED_GENERATION) + self.assertEqual(ouptut, EXPECTED_GENERATION) + + @slow + @require_bitsandbytes + def test_pixtral_4bit(self): + model_id = "mistral-community/pixtral-12b" + model = LlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True) + processor = AutoProcessor.from_pretrained(model_id) + + IMG_URLS = [ + Image.open(requests.get("https://picsum.photos/id/237/400/300", stream=True).raw), + Image.open(requests.get("https://picsum.photos/id/231/200/300", stream=True).raw), + ] + PROMPT = "[INST][IMG][IMG]Describe the images.[/INST]" + + inputs = processor(text=PROMPT, images=IMG_URLS, return_tensors="pt").to(torch_device, torch.float16) + generate_ids = model.generate(**inputs, max_new_tokens=50) + output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + + EXPECTED_GENERATION = "Describe the images.The image showcases a dog, which is prominently positioned in the center, taking up a significant portion of the frame. The dog is situated against a backdrop of a wooden surface, which spans the entire image. 
The dog appears to be a black Labrador" # fmt: skip + self.assertEqual(output, EXPECTED_GENERATION) + + @slow + @require_bitsandbytes + def test_pixtral_batched(self): + model_id = "mistral-community/pixtral-12b" + model = LlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True) + processor = AutoProcessor.from_pretrained(model_id) + processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id + + IMG_URLS = [ + Image.open(requests.get("https://picsum.photos/id/237/400/300", stream=True).raw), + Image.open(requests.get("https://picsum.photos/id/17/150/500", stream=True).raw), + ] + PROMPT = [ + "[INST][IMG]What breed is the dog?[/INST]", + "[INST][IMG]What is shown in this image?[/INST]", + ] + + inputs = processor(text=PROMPT, images=IMG_URLS, padding=True, return_tensors="pt").to( + torch_device, torch.float16 + ) + generate_ids = model.generate(**inputs, max_new_tokens=50) + output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) + + EXPECTED_GENERATION = [ + 'What breed is the dog?The dog in the image is a black Labrador Retriever.', + 'What is shown in this image?The image depicts a narrow, winding dirt path surrounded by lush greenery. The path is flanked by grass and shrubs on both sides. On the left side, there are tall trees and dense foliage, while on the right side, there' + ] # fmt: skip + self.assertEqual(output, EXPECTED_GENERATION) diff --git a/tests/models/pixtral/test_image_processing_pixtral.py b/tests/models/pixtral/test_image_processing_pixtral.py index 19bfde038f2a..cc3fbba3d275 100644 --- a/tests/models/pixtral/test_image_processing_pixtral.py +++ b/tests/models/pixtral/test_image_processing_pixtral.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import random import time import unittest @@ -92,49 +91,47 @@ def prepare_image_processor_dict(self): "do_convert_rgb": self.do_convert_rgb, } - def expected_output_image_shape(self, image): - if isinstance(image, Image.Image): - width, height = image.size - elif isinstance(image, np.ndarray): - height, width = image.shape[:2] - elif isinstance(image, torch.Tensor): - height, width = image.shape[-2:] + def expected_output_image_shape(self, images): + if not isinstance(images, (list, tuple)): + images = [images] - max_height = max_width = self.size.get("longest_edge") + batch_size = len(images) + return_height, return_width = 0, 0 + for image in images: + if isinstance(image, Image.Image): + width, height = image.size + elif isinstance(image, np.ndarray): + height, width = image.shape[:2] + elif isinstance(image, torch.Tensor): + height, width = image.shape[-2:] - ratio = max(height / max_height, width / max_width) - if ratio > 1: - height = int(np.ceil(height / ratio)) - width = int(np.ceil(width / ratio)) + max_height = max_width = self.size.get("longest_edge") - patch_height, patch_width = self.patch_size["height"], self.patch_size["width"] - num_height_tokens = (height - 1) // patch_height + 1 - num_width_tokens = (width - 1) // patch_width + 1 + ratio = max(height / max_height, width / max_width) + if ratio > 1: + height = int(np.ceil(height / ratio)) + width = int(np.ceil(width / ratio)) - height = num_height_tokens * patch_height - width = num_width_tokens * patch_width + patch_height, patch_width = self.patch_size["height"], self.patch_size["width"] + num_height_tokens = (height - 1) // patch_height + 1 + num_width_tokens = (width - 1) // patch_width + 1 - return self.num_channels, height, width + return_height = max(num_height_tokens * patch_height, return_height) + return_width = max(num_width_tokens * patch_width, return_width) + + return batch_size, self.num_channels, return_height, return_width def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False): - # Use prepare_image_inputs to make a list of list of single images - - images_list = [] - for _ in range(self.batch_size): - images = [] - for _ in range(random.randint(1, self.max_num_images_per_sample)): - img = prepare_image_inputs( - batch_size=1, - num_channels=self.num_channels, - min_resolution=self.min_resolution, - max_resolution=self.max_resolution, - equal_resolution=equal_resolution, - numpify=numpify, - torchify=torchify, - )[0] - images.append(img) - images_list.append(images) - return images_list + images = prepare_image_inputs( + batch_size=self.batch_size, + num_channels=self.num_channels, + min_resolution=self.min_resolution, + max_resolution=self.max_resolution, + equal_resolution=equal_resolution, + numpify=numpify, + torchify=torchify, + ) + return images @require_torch @@ -173,23 +170,18 @@ def test_call_pil(self): image_processing = image_processing_class(**self.image_processor_dict) # create random PIL images image_inputs_list = self.image_processor_tester.prepare_image_inputs() - for image_inputs in image_inputs_list: - for image in image_inputs: - self.assertIsInstance(image, Image.Image) + for image in image_inputs_list: + self.assertIsInstance(image, Image.Image) # Test not batched input - encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values - expected_output_image_shape = self.image_processor_tester.expected_output_image_shape( - image_inputs_list[0][0] - ) - self.assertEqual(tuple(encoded_images[0][0].shape), 
expected_output_image_shape) + encoded_images = image_processing(image_inputs_list[0], return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0]) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) # Test batched - batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values - for encoded_images, images in zip(batch_encoded_images, image_inputs_list): - for encoded_image, image in zip(encoded_images, images): - expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image) - self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape) + encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) def test_call_numpy(self): for image_processing_class in self.image_processor_list: @@ -197,23 +189,18 @@ def test_call_numpy(self): image_processing = image_processing_class(**self.image_processor_dict) # create random numpy tensors image_inputs_list = self.image_processor_tester.prepare_image_inputs(numpify=True) - for image_inputs in image_inputs_list: - for image in image_inputs: - self.assertIsInstance(image, np.ndarray) + for image in image_inputs_list: + self.assertIsInstance(image, np.ndarray) # Test not batched input - encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values - expected_output_image_shape = self.image_processor_tester.expected_output_image_shape( - image_inputs_list[0][0] - ) - self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape) + encoded_images = image_processing(image_inputs_list[0], return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0]) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) # Test batched batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values - for encoded_images, images in zip(batch_encoded_images, image_inputs_list): - for encoded_image, image in zip(encoded_images, images): - expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image) - self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape) + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list) + self.assertEqual(tuple(batch_encoded_images.shape), expected_output_image_shape) def test_call_pytorch(self): for image_processing_class in self.image_processor_list: @@ -221,23 +208,18 @@ def test_call_pytorch(self): image_processing = image_processing_class(**self.image_processor_dict) # create random PyTorch tensors image_inputs_list = self.image_processor_tester.prepare_image_inputs(torchify=True) - for image_inputs in image_inputs_list: - for image in image_inputs: - self.assertIsInstance(image, torch.Tensor) + for image in image_inputs_list: + self.assertIsInstance(image, torch.Tensor) # Test not batched input - encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values - expected_output_image_shape = self.image_processor_tester.expected_output_image_shape( - image_inputs_list[0][0] - ) - self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape) + 
encoded_images = image_processing(image_inputs_list[0], return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0]) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) # Test batched batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values - for encoded_images, images in zip(batch_encoded_images, image_inputs_list): - for encoded_image, image in zip(encoded_images, images): - expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image) - self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape) + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list) + self.assertEqual(tuple(batch_encoded_images.shape), expected_output_image_shape) @require_vision @require_torch diff --git a/tests/models/pixtral/test_modeling_pixtral.py b/tests/models/pixtral/test_modeling_pixtral.py index 3e5667caf45e..f254d9eecd04 100644 --- a/tests/models/pixtral/test_modeling_pixtral.py +++ b/tests/models/pixtral/test_modeling_pixtral.py @@ -74,15 +74,17 @@ def __init__( self.initializer_range = initializer_range self.scope = scope - # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token) - num_patches = (image_size // patch_size) ** 2 - self.seq_length = num_patches + 1 + # in Pixtral, the seq length equals the number of patches * batch_size because the patches are flattened + self.seq_length = (image_size // patch_size) ** 2 * batch_size def prepare_config_and_inputs(self): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + image_sizes = torch.tensor( + [[self.image_size, self.image_size]] * self.batch_size, dtype=torch.long, device=torch_device + ) config = self.get_config() - return config, pixel_values + return config, pixel_values, image_sizes def get_config(self): return PixtralVisionConfig( @@ -127,8 +129,8 @@ def create_and_check_model_with_projection(self, config, pixel_values): def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() - config, pixel_values = config_and_inputs - inputs_dict = {"pixel_values": pixel_values} + config, pixel_values, image_sizes = config_and_inputs + inputs_dict = {"pixel_values": pixel_values, "image_sizes": image_sizes} return config, inputs_dict @@ -142,113 +144,17 @@ class PixtralVisionModelModelTest(ModelTesterMixin, unittest.TestCase): test_pruning = False test_head_masking = False test_torchscript = False + test_resize_embeddings = False def setUp(self): self.model_tester = PixtralVisionModelTester(self) self.config_tester = ConfigTester(self, config_class=PixtralVisionConfig, has_text_modality=False) - @unittest.skip("model does not support input embeds") - def test_inputs_embeds(self): - pass - - @unittest.skip("model does not support input embeds") - def test_inputs_embeds_matches_input_ids(self): - pass - - @unittest.skip( - reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" - ) - def test_training_gradient_checkpointing(self): - pass - - @unittest.skip( - reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" - ) - def test_training_gradient_checkpointing_use_reentrant(self): - pass - - @unittest.skip( - 
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" - ) - def test_training_gradient_checkpointing_use_reentrant_false(self): - pass - - @unittest.skip(reason="Compile not yet supported because in Pixtral models") - def test_sdpa_can_compile_dynamic(self): - pass - - @unittest.skip(reason="Compile not yet supported because in Pixtral models") - def test_sdpa_can_dispatch_on_flash(self): - pass - - @unittest.skip(reason="Not supported yet") - def test_attention_outputs(self): - pass - - @unittest.skip(reason="Not supported yet") - def test_cpu_offload(self): - pass - - @unittest.skip(reason="Not supported yet") - def test_batching_equivalence(self): - pass - - @unittest.skip(reason="Not supported yet") - def test_disk_offload_bin(self): - pass - - @unittest.skip(reason="Not supported yet") - def test_retain_grad_hidden_states_attentions(self): - pass - - @unittest.skip(reason="Not supported yet") - def test_multi_gpu_data_parallel_forward(self): - pass - - @unittest.skip(reason="Not supported yet") - def test_model_parallelism(self): - pass - - @unittest.skip(reason="Not supported yet") - def test_model_outputs_equivalence(self): - pass - - @unittest.skip(reason="Not supported yet") - def test_save_load(self): - pass - - @unittest.skip(reason="Not supported yet") def test_model_get_set_embeddings(self): - pass - - @unittest.skip(reason="Not supported yet") - def test_resize_tokens_embeddings(self): - pass - - @unittest.skip(reason="Not supported yet") - def test_model_main_input_name(self): - pass - - @unittest.skip(reason="Not supported yet") - def test_initialization(self): - pass - - @unittest.skip(reason="Not supported yet") - def test_hidden_states_output(self): - pass - - @unittest.skip(reason="Not supported yet") - def test_gradient_checkpointing_backward_compatibility(self): - pass - - @unittest.skip(reason="Not supported yet") - def test_feed_forward_chunking(self): - pass - - @unittest.skip(reason="Not supported yet") - def test_disk_offload_safetensors(self): - pass + config, _ = self.model_tester.prepare_config_and_inputs_for_common() - @unittest.skip(reason="Not supported yet") - def test_determinism(self): - pass + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), (torch.nn.Module)) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, torch.nn.Linear)) diff --git a/tests/models/pixtral/test_processor_pixtral.py b/tests/models/pixtral/test_processor_pixtral.py index d224c531241f..a678e7c0102c 100644 --- a/tests/models/pixtral/test_processor_pixtral.py +++ b/tests/models/pixtral/test_processor_pixtral.py @@ -14,7 +14,6 @@ import shutil import tempfile import unittest -from typing import Optional import requests import torch @@ -28,7 +27,7 @@ if is_vision_available(): from PIL import Image - from transformers import AutoTokenizer, PixtralImageProcessor, PixtralProcessor + from transformers import PixtralProcessor @require_vision @@ -46,20 +45,15 @@ def setUpClass(cls): def setUp(self): self.tmpdirname = tempfile.mkdtemp() - - # FIXME - just load the processor directly from the checkpoint - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/pixtral-12b") - image_processor = PixtralImageProcessor() - processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor) + processor = PixtralProcessor.from_pretrained("mistral-community/pixtral-12b") 
processor.save_pretrained(self.tmpdirname) def tearDown(self): shutil.rmtree(self.tmpdirname) - @unittest.skip("No chat template was set for this model (yet)") def test_chat_template(self): processor = self.processor_class.from_pretrained(self.tmpdirname) - expected_prompt = "USER: [IMG]\nWhat is shown in this image? ASSISTANT:" + expected_prompt = "[INST][IMG]What is shown in this image?[/INST]" messages = [ { @@ -73,13 +67,12 @@ def test_chat_template(self): formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True) self.assertEqual(expected_prompt, formatted_prompt) - @unittest.skip("No chat template was set for this model (yet)") def test_image_token_filling(self): processor = self.processor_class.from_pretrained(self.tmpdirname) # Important to check with non square image image = torch.randint(0, 2, (3, 500, 316)) - expected_image_tokens = 1526 - image_token_index = 32000 + expected_image_tokens = 640 + image_token_index = 10 messages = [ { @@ -111,11 +104,8 @@ def test_processor_with_single_image(self): self.assertIn("input_ids", inputs_image) self.assertTrue(len(inputs_image["input_ids"]) == 1) self.assertIsInstance(inputs_image["input_ids"], torch.Tensor) - self.assertIsInstance(inputs_image["pixel_values"], list) - self.assertTrue(len(inputs_image["pixel_values"]) == 1) - self.assertIsInstance(inputs_image["pixel_values"][0], list) - self.assertTrue(len(inputs_image["pixel_values"][0]) == 1) - self.assertIsInstance(inputs_image["pixel_values"][0][0], torch.Tensor) + self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor) + self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 32, 32])) # fmt: off input_ids = inputs_image["input_ids"] @@ -131,11 +121,8 @@ def test_processor_with_single_image(self): self.assertIn("input_ids", inputs_url) self.assertTrue(len(inputs_url["input_ids"]) == 1) self.assertIsInstance(inputs_url["input_ids"], torch.Tensor) - self.assertIsInstance(inputs_url["pixel_values"], list) - self.assertTrue(len(inputs_url["pixel_values"]) == 1) - self.assertIsInstance(inputs_url["pixel_values"][0], list) - self.assertTrue(len(inputs_url["pixel_values"][0]) == 1) - self.assertIsInstance(inputs_url["pixel_values"][0][0], torch.Tensor) + self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor) + self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 32, 32])) # fmt: off input_ids = inputs_url["input_ids"] @@ -146,6 +133,28 @@ def test_processor_with_single_image(self): ) # fmt: on + # Test passing inputs as a single list + inputs_image = processor(text=prompt_string, images=[self.image_0], return_tensors="pt") + self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 32, 32])) + + # fmt: off + self.assertEqual( + inputs_image["input_ids"][0].tolist(), + [21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058] + ) + # fmt: on + + # Test as nested single list + inputs_image = processor(text=prompt_string, images=[[self.image_0]], return_tensors="pt") + self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 32, 32])) + + # fmt: off + self.assertEqual( + inputs_image["input_ids"][0].tolist(), + [21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058] + ) + # fmt: on + def test_processor_with_multiple_images_single_list(self): processor = self.processor_class.from_pretrained(self.tmpdirname) prompt_string = "USER: 
[IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:" @@ -159,11 +168,8 @@ def test_processor_with_multiple_images_single_list(self): self.assertIn("input_ids", inputs_image) self.assertTrue(len(inputs_image["input_ids"]) == 1) self.assertIsInstance(inputs_image["input_ids"], torch.Tensor) - self.assertIsInstance(inputs_image["pixel_values"], list) - self.assertTrue(len(inputs_image["pixel_values"]) == 1) - self.assertIsInstance(inputs_image["pixel_values"][0], list) - self.assertTrue(len(inputs_image["pixel_values"][0]) == 2) - self.assertIsInstance(inputs_image["pixel_values"][0][0], torch.Tensor) + self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor) + self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([2, 3, 32, 32])) # fmt: off input_ids = inputs_image["input_ids"] @@ -179,11 +185,9 @@ def test_processor_with_multiple_images_single_list(self): self.assertIn("input_ids", inputs_url) self.assertTrue(len(inputs_url["input_ids"]) == 1) self.assertIsInstance(inputs_url["input_ids"], torch.Tensor) - self.assertIsInstance(inputs_url["pixel_values"], list) - self.assertTrue(len(inputs_url["pixel_values"]) == 1) - self.assertIsInstance(inputs_url["pixel_values"][0], list) - self.assertTrue(len(inputs_url["pixel_values"][0]) == 2) - self.assertIsInstance(inputs_url["pixel_values"][0][0], torch.Tensor) + self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor) + self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([2, 3, 32, 32])) + # fmt: off input_ids = inputs_url["input_ids"] self.assertEqual( @@ -193,6 +197,17 @@ def test_processor_with_multiple_images_single_list(self): ) # fmt: on + # Test passing in as a nested list + inputs_url = processor(text=prompt_string, images=[[self.image_0, self.image_1]], return_tensors="pt") + self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([2, 3, 32, 32])) + + # fmt: off + self.assertEqual( + inputs_url["input_ids"][0].tolist(), + [21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058] + ) + # fmt: on + def test_processor_with_multiple_images_multiple_lists(self): processor = self.processor_class.from_pretrained(self.tmpdirname) prompt_string = [ @@ -211,11 +226,8 @@ def test_processor_with_multiple_images_multiple_lists(self): self.assertIn("input_ids", inputs_image) self.assertTrue(len(inputs_image["input_ids"]) == 2) self.assertIsInstance(inputs_image["input_ids"], torch.Tensor) - self.assertIsInstance(inputs_image["pixel_values"], list) - self.assertTrue(len(inputs_image["pixel_values"]) == 2) - self.assertIsInstance(inputs_image["pixel_values"][0], list) - self.assertTrue(len(inputs_image["pixel_values"][0]) == 2) - self.assertIsInstance(inputs_image["pixel_values"][0][0], torch.Tensor) + self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor) + self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([3, 3, 32, 32])) # fmt: off input_ids = inputs_image["input_ids"] @@ -231,11 +243,8 @@ def test_processor_with_multiple_images_multiple_lists(self): self.assertIn("input_ids", inputs_url) self.assertTrue(len(inputs_url["input_ids"]) == 2) self.assertIsInstance(inputs_url["input_ids"], torch.Tensor) - self.assertIsInstance(inputs_url["pixel_values"], list) - self.assertTrue(len(inputs_url["pixel_values"]) == 2) - self.assertIsInstance(inputs_url["pixel_values"][0], list) - self.assertTrue(len(inputs_url["pixel_values"][0]) == 2) - 
self.assertIsInstance(inputs_url["pixel_values"][0][0], torch.Tensor) + self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor) + self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([3, 3, 32, 32])) # fmt: off input_ids = inputs_url["input_ids"] @@ -246,6 +255,19 @@ def test_processor_with_multiple_images_multiple_lists(self): ) # fmt: on + # Test passing as a single flat list + inputs_image = processor( + text=prompt_string, images=[self.image_0, self.image_1, self.image_2], return_tensors="pt", padding=True + ) + self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([3, 3, 32, 32])) + + # fmt: off + self.assertEqual( + inputs_image["input_ids"][0].tolist(), + [21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058] + ) + # fmt: on + def test_processor_returns_full_length_batches(self): # to avoid https://github.com/huggingface/transformers/issues/34204 processor = self.processor_class.from_pretrained(self.tmpdirname) @@ -264,13 +286,3 @@ def test_processor_returns_full_length_batches(self): self.assertIn("input_ids", inputs_image) self.assertTrue(len(inputs_image["input_ids"]) == 5) self.assertTrue(len(inputs_image["pixel_values"]) == 5) - - # Override as PixtralProcessor needs nested images to work properly with batched inputs - @require_vision - def prepare_image_inputs(self, batch_size: Optional[int] = None): - """This function prepares a list of PIL images for testing""" - if batch_size is None: - return super().prepare_image_inputs() - if batch_size < 1: - raise ValueError("batch_size must be greater than 0") - return [[super().prepare_image_inputs()]] * batch_size diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 3f00b4e15d4b..ba996a966cc8 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2991,6 +2991,10 @@ def test_inputs_embeds(self): model.to(torch_device) model.eval() + model_forward_args = inspect.signature(model.forward).parameters + if "inputs_embeds" not in model_forward_args: + self.skipTest(reason="This model doesn't use `inputs_embeds`") + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) if not self.is_encoder_decoder: