diff --git a/src/transformers/models/dpt/image_processing_dpt.py b/src/transformers/models/dpt/image_processing_dpt.py
index c49dfcfef890..3c2162409c57 100644
--- a/src/transformers/models/dpt/image_processing_dpt.py
+++ b/src/transformers/models/dpt/image_processing_dpt.py
@@ -139,6 +139,11 @@ class DPTImageProcessor(BaseImageProcessor):
         size_divisor (`int`, *optional*):
             If `do_pad` is `True`, pads the image dimensions to be divisible by this value. This was introduced in the
             DINOv2 paper, which uses the model in combination with DPT.
+        do_reduce_labels (`bool`, *optional*, defaults to `False`):
+            Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
+            used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
+            background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the
+            `preprocess` method.
     """
 
     model_input_names = ["pixel_values"]
@@ -157,6 +162,7 @@ def __init__(
         image_std: Optional[Union[float, List[float]]] = None,
         do_pad: bool = False,
         size_divisor: int = None,
+        do_reduce_labels: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -174,6 +180,7 @@ def __init__(
         self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
         self.do_pad = do_pad
         self.size_divisor = size_divisor
+        self.do_reduce_labels = do_reduce_labels
 
     def resize(
         self,
@@ -275,10 +282,160 @@ def _get_pad(size, size_divisor):
 
         return pad(image, ((pad_size_left, pad_size_right), (pad_size_top, pad_size_bottom)), data_format=data_format)
 
+    # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.reduce_label
+    def reduce_label(self, label: ImageInput) -> np.ndarray:
+        label = to_numpy_array(label)
+        # Avoid using underflow conversion
+        label[label == 0] = 255
+        label = label - 1
+        label[label == 254] = 255
+        return label
+
+    def _preprocess(
+        self,
+        image: ImageInput,
+        do_reduce_labels: bool = None,
+        do_resize: bool = None,
+        size: Dict[str, int] = None,
+        resample: PILImageResampling = None,
+        keep_aspect_ratio: bool = None,
+        ensure_multiple_of: int = None,
+        do_rescale: bool = None,
+        rescale_factor: float = None,
+        do_normalize: bool = None,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        do_pad: bool = None,
+        size_divisor: int = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ):
+        if do_reduce_labels:
+            image = self.reduce_label(image)
+
+        if do_resize:
+            image = self.resize(
+                image=image,
+                size=size,
+                resample=resample,
+                keep_aspect_ratio=keep_aspect_ratio,
+                ensure_multiple_of=ensure_multiple_of,
+                input_data_format=input_data_format,
+            )
+
+        if do_rescale:
+            image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+        if do_normalize:
+            image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+
+        if do_pad:
+            image = self.pad_image(image=image, size_divisor=size_divisor, input_data_format=input_data_format)
+
+        return image
+
+    def _preprocess_image(
+        self,
+        image: ImageInput,
+        do_resize: bool = None,
+        size: Dict[str, int] = None,
+        resample: PILImageResampling = None,
+        keep_aspect_ratio: bool = None,
+        ensure_multiple_of: int = None,
+        do_rescale: bool = None,
+        rescale_factor: float = None,
+        do_normalize: bool = None,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        do_pad: bool = None,
+        size_divisor: int = None,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> np.ndarray:
+        """Preprocesses a single image."""
+        # All transformations expect numpy arrays.
+        image = to_numpy_array(image)
+        if do_rescale and is_scaled_image(image):
+            logger.warning_once(
+                "It looks like you are trying to rescale already rescaled images. If the input"
+                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+            )
+        if input_data_format is None:
+            # We assume that all images have the same channel dimension format.
+            input_data_format = infer_channel_dimension_format(image)
+
+        image = self._preprocess(
+            image,
+            do_reduce_labels=False,
+            do_resize=do_resize,
+            size=size,
+            resample=resample,
+            keep_aspect_ratio=keep_aspect_ratio,
+            ensure_multiple_of=ensure_multiple_of,
+            do_rescale=do_rescale,
+            rescale_factor=rescale_factor,
+            do_normalize=do_normalize,
+            image_mean=image_mean,
+            image_std=image_std,
+            do_pad=do_pad,
+            size_divisor=size_divisor,
+            input_data_format=input_data_format,
+        )
+        if data_format is not None:
+            image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+        return image
+
+    def _preprocess_segmentation_map(
+        self,
+        segmentation_map: ImageInput,
+        do_resize: bool = None,
+        size: Dict[str, int] = None,
+        resample: PILImageResampling = None,
+        keep_aspect_ratio: bool = None,
+        ensure_multiple_of: int = None,
+        do_reduce_labels: bool = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ):
+        """Preprocesses a single segmentation map."""
+        # All transformations expect numpy arrays.
+        segmentation_map = to_numpy_array(segmentation_map)
+        # Add an axis to the segmentation maps for transformations.
+        if segmentation_map.ndim == 2:
+            segmentation_map = segmentation_map[None, ...]
+            added_dimension = True
+            input_data_format = ChannelDimension.FIRST
+        else:
+            added_dimension = False
+            if input_data_format is None:
+                input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
+        segmentation_map = self._preprocess(
+            image=segmentation_map,
+            do_reduce_labels=do_reduce_labels,
+            do_resize=do_resize,
+            size=size,
+            resample=resample,
+            keep_aspect_ratio=keep_aspect_ratio,
+            ensure_multiple_of=ensure_multiple_of,
+            do_normalize=False,
+            do_rescale=False,
+            input_data_format=input_data_format,
+        )
+        # Remove extra axis if added
+        if added_dimension:
+            segmentation_map = np.squeeze(segmentation_map, axis=0)
+        segmentation_map = segmentation_map.astype(np.int64)
+        return segmentation_map
+
+    # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.__call__
+    def __call__(self, images, segmentation_maps=None, **kwargs):
+        # Overrides the `__call__` method of the `Preprocessor` class such that the images and segmentation maps can both
+        # be passed in as positional arguments.
+        return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
+
     @filter_out_non_signature_kwargs()
     def preprocess(
         self,
         images: ImageInput,
+        segmentation_maps: Optional[ImageInput] = None,
         do_resize: bool = None,
         size: int = None,
         keep_aspect_ratio: bool = None,
@@ -291,6 +448,7 @@ def preprocess(
         image_std: Optional[Union[float, List[float]]] = None,
         do_pad: bool = None,
         size_divisor: int = None,
+        do_reduce_labels: Optional[bool] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         data_format: ChannelDimension = ChannelDimension.FIRST,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
@@ -302,6 +460,8 @@ def preprocess(
             images (`ImageInput`):
                 Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                 passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+            segmentation_maps (`ImageInput`, *optional*):
+                Segmentation map to preprocess.
             do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                 Whether to resize the image.
             size (`Dict[str, int]`, *optional*, defaults to `self.size`):
@@ -326,6 +486,10 @@ def preprocess(
                 Image mean.
             image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                 Image standard deviation.
+            do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
+                Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
+                is used for background, and background itself is not included in all classes of a dataset (e.g.
+                ADE20k). The background label will be replaced by 255.
             return_tensors (`str` or `TensorType`, *optional*):
                 The type of tensors to return. Can be one of:
                     - Unset: Return a list of `np.ndarray`.
@@ -357,9 +521,13 @@ def preprocess(
         image_std = image_std if image_std is not None else self.image_std
         do_pad = do_pad if do_pad is not None else self.do_pad
         size_divisor = size_divisor if size_divisor is not None else self.size_divisor
+        do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
 
         images = make_list_of_images(images)
 
+        if segmentation_maps is not None:
+            segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)
+
         if not valid_images(images):
             raise ValueError(
                 "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
@@ -377,55 +545,47 @@ def preprocess(
                 size=size,
                 resample=resample,
             )
-        # All transformations expect numpy arrays.
-        images = [to_numpy_array(image) for image in images]
-
-        if do_rescale and is_scaled_image(images[0]):
-            logger.warning_once(
-                "It looks like you are trying to rescale already rescaled images. If the input"
-                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+        images = [
+            self._preprocess_image(
+                image=img,
+                do_resize=do_resize,
+                do_rescale=do_rescale,
+                do_normalize=do_normalize,
+                do_pad=do_pad,
+                size=size,
+                resample=resample,
+                keep_aspect_ratio=keep_aspect_ratio,
+                ensure_multiple_of=ensure_multiple_of,
+                rescale_factor=rescale_factor,
+                image_mean=image_mean,
+                image_std=image_std,
+                size_divisor=size_divisor,
+                data_format=data_format,
+                input_data_format=input_data_format,
             )
+            for img in images
+        ]
 
-        if input_data_format is None:
-            # We assume that all images have the same channel dimension format.
-            input_data_format = infer_channel_dimension_format(images[0])
+        data = {"pixel_values": images}
 
-        if do_resize:
-            images = [
-                self.resize(
-                    image=image,
+        if segmentation_maps is not None:
+            segmentation_maps = [
+                self._preprocess_segmentation_map(
+                    segmentation_map=segmentation_map,
+                    do_reduce_labels=do_reduce_labels,
+                    do_resize=do_resize,
                     size=size,
                     resample=resample,
                     keep_aspect_ratio=keep_aspect_ratio,
                     ensure_multiple_of=ensure_multiple_of,
                     input_data_format=input_data_format,
                 )
-                for image in images
+                for segmentation_map in segmentation_maps
             ]
 
-        if do_rescale:
-            images = [
-                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
-                for image in images
-            ]
+            data["labels"] = segmentation_maps
 
-        if do_normalize:
-            images = [
-                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
-                for image in images
-            ]
-
-        if do_pad:
-            images = [
-                self.pad_image(image=image, size_divisor=size_divisor, input_data_format=input_data_format)
-                for image in images
-            ]
-
-        images = [
-            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
-        ]
-
-        data = {"pixel_values": images}
         return BatchFeature(data=data, tensor_type=return_tensors)
 
     # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->DPT
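For reference, here is a minimal usage sketch of the API this diff enables. It is not part of the patch; the file names are illustrative, and `do_reduce_labels=True` assumes an ADE20k-style dataset where label 0 is background.

```python
from PIL import Image

from transformers import DPTImageProcessor

# Hypothetical local files: an RGB image and its single-channel segmentation map.
image = Image.open("scene.jpg")
segmentation_map = Image.open("scene_seg.png")

image_processor = DPTImageProcessor(do_reduce_labels=True)

# Images and segmentation maps can now both be passed positionally. Maps are
# resized along with the images but are never rescaled or normalized.
encoding = image_processor(image, segmentation_map, return_tensors="pt")

print(encoding["pixel_values"].shape)  # (1, 3, height, width), float
print(encoding["labels"].shape)        # (1, height, width), dtype torch.long
```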
diff --git a/tests/models/dpt/test_image_processing_dpt.py b/tests/models/dpt/test_image_processing_dpt.py
index f68e9bb6130a..713c722a4c2b 100644
--- a/tests/models/dpt/test_image_processing_dpt.py
+++ b/tests/models/dpt/test_image_processing_dpt.py
@@ -17,14 +17,20 @@
 import unittest
 
 import numpy as np
+from datasets import load_dataset
 
-from transformers.file_utils import is_vision_available
+from transformers.file_utils import is_torch_available, is_vision_available
 from transformers.testing_utils import require_torch, require_vision
 
 from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
 
 
+if is_torch_available():
+    import torch
+
 if is_vision_available():
+    from PIL import Image
+
     from transformers import DPTImageProcessor
 
 
@@ -42,6 +48,7 @@ def __init__(
         do_normalize=True,
         image_mean=[0.5, 0.5, 0.5],
         image_std=[0.5, 0.5, 0.5],
+        do_reduce_labels=False,
     ):
         super().__init__()
         size = size if size is not None else {"height": 18, "width": 18}
@@ -56,6 +63,7 @@ def __init__(
         self.do_normalize = do_normalize
         self.image_mean = image_mean
         self.image_std = image_std
+        self.do_reduce_labels = do_reduce_labels
 
     def prepare_image_processor_dict(self):
         return {
@@ -64,6 +72,7 @@ def prepare_image_processor_dict(self):
             "do_normalize": self.do_normalize,
             "do_resize": self.do_resize,
             "size": self.size,
+            "do_reduce_labels": self.do_reduce_labels,
         }
 
     def expected_output_image_shape(self, images):
@@ -81,6 +90,28 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F
         )
 
 
+# Copied from transformers.tests.models.beit.test_image_processing_beit.prepare_semantic_single_inputs
+def prepare_semantic_single_inputs():
+    dataset = load_dataset("hf-internal-testing/fixtures_ade20k", split="test", trust_remote_code=True)
+
+    image = Image.open(dataset[0]["file"])
+    map = Image.open(dataset[1]["file"])
+
+    return image, map
+
+
+# Copied from transformers.tests.models.beit.test_image_processing_beit.prepare_semantic_batch_inputs
+def prepare_semantic_batch_inputs():
+    ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test", trust_remote_code=True)
+
+    image1 = Image.open(ds[0]["file"])
+    map1 = Image.open(ds[1]["file"])
+    image2 = Image.open(ds[2]["file"])
+    map2 = Image.open(ds[3]["file"])
+
+    return [image1, image2], [map1, map2]
+
+
 @require_torch
 @require_vision
 class DPTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
@@ -105,6 +136,7 @@ def test_image_processor_properties(self):
         self.assertTrue(hasattr(image_processing, "rescale_factor"))
         self.assertTrue(hasattr(image_processing, "do_pad"))
         self.assertTrue(hasattr(image_processing, "size_divisor"))
+        self.assertTrue(hasattr(image_processing, "do_reduce_labels"))
 
     def test_image_processor_from_dict_with_kwargs(self):
         image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
@@ -138,3 +170,126 @@ def test_keep_aspect_ratio(self):
         pixel_values = image_processor(image, return_tensors="pt").pixel_values
 
         self.assertEqual(list(pixel_values.shape), [1, 3, 512, 672])
+
+    # Copied from transformers.tests.models.beit.test_image_processing_beit.BeitImageProcessingTest.test_call_segmentation_maps
+    def test_call_segmentation_maps(self):
+        # Initialize image_processor
+        image_processor = self.image_processing_class(**self.image_processor_dict)
+        # create random PyTorch tensors
+        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
+        maps = []
+        for image in image_inputs:
+            self.assertIsInstance(image, torch.Tensor)
+            maps.append(torch.zeros(image.shape[-2:]).long())
+
+        # Test not batched input
+        encoding = image_processor(image_inputs[0], maps[0], return_tensors="pt")
+        self.assertEqual(
+            encoding["pixel_values"].shape,
+            (
+                1,
+                self.image_processor_tester.num_channels,
+                self.image_processor_tester.size["height"],
+                self.image_processor_tester.size["width"],
+            ),
+        )
+        self.assertEqual(
+            encoding["labels"].shape,
+            (
+                1,
+                self.image_processor_tester.size["height"],
+                self.image_processor_tester.size["width"],
+            ),
+        )
+        self.assertEqual(encoding["labels"].dtype, torch.long)
+        self.assertTrue(encoding["labels"].min().item() >= 0)
+        self.assertTrue(encoding["labels"].max().item() <= 255)
+
+        # Test batched
+        encoding = image_processor(image_inputs, maps, return_tensors="pt")
+        self.assertEqual(
+            encoding["pixel_values"].shape,
+            (
+                self.image_processor_tester.batch_size,
+                self.image_processor_tester.num_channels,
+                self.image_processor_tester.size["height"],
+                self.image_processor_tester.size["width"],
+            ),
+        )
+        self.assertEqual(
+            encoding["labels"].shape,
+            (
+                self.image_processor_tester.batch_size,
+                self.image_processor_tester.size["height"],
+                self.image_processor_tester.size["width"],
+            ),
+        )
+        self.assertEqual(encoding["labels"].dtype, torch.long)
+        self.assertTrue(encoding["labels"].min().item() >= 0)
+        self.assertTrue(encoding["labels"].max().item() <= 255)
+
+        # Test not batched input (PIL images)
+        image, segmentation_map = prepare_semantic_single_inputs()
+
+        encoding = image_processor(image, segmentation_map, return_tensors="pt")
+        self.assertEqual(
+            encoding["pixel_values"].shape,
+            (
+                1,
+                self.image_processor_tester.num_channels,
+                self.image_processor_tester.size["height"],
+                self.image_processor_tester.size["width"],
+            ),
+        )
+        self.assertEqual(
+            encoding["labels"].shape,
+            (
+                1,
+                self.image_processor_tester.size["height"],
+                self.image_processor_tester.size["width"],
+            ),
+        )
+        self.assertEqual(encoding["labels"].dtype, torch.long)
+        self.assertTrue(encoding["labels"].min().item() >= 0)
+        self.assertTrue(encoding["labels"].max().item() <= 255)
+
+        # Test batched input (PIL images)
+        images, segmentation_maps = prepare_semantic_batch_inputs()
+
+        encoding = image_processor(images, segmentation_maps, return_tensors="pt")
+        self.assertEqual(
+            encoding["pixel_values"].shape,
+            (
+                2,
+                self.image_processor_tester.num_channels,
+                self.image_processor_tester.size["height"],
+                self.image_processor_tester.size["width"],
+            ),
+        )
+        self.assertEqual(
+            encoding["labels"].shape,
+            (
+                2,
+                self.image_processor_tester.size["height"],
+                self.image_processor_tester.size["width"],
+            ),
+        )
+        self.assertEqual(encoding["labels"].dtype, torch.long)
+        self.assertTrue(encoding["labels"].min().item() >= 0)
+        self.assertTrue(encoding["labels"].max().item() <= 255)
+
+    # Copied from transformers.tests.models.beit.test_image_processing_beit.BeitImageProcessingTest.test_reduce_labels
+    def test_reduce_labels(self):
+        # Initialize image_processor
+        image_processor = self.image_processing_class(**self.image_processor_dict)
+
+        # ADE20k has 150 classes, and the background is included, so labels should be between 0 and 150
+        image, map = prepare_semantic_single_inputs()
+        encoding = image_processor(image, map, return_tensors="pt")
+        self.assertTrue(encoding["labels"].min().item() >= 0)
+        self.assertTrue(encoding["labels"].max().item() <= 150)
+
+        image_processor.do_reduce_labels = True
+        encoding = image_processor(image, map, return_tensors="pt")
+        self.assertTrue(encoding["labels"].min().item() >= 0)
+        self.assertTrue(encoding["labels"].max().item() <= 255)
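As a sanity check on the label semantics exercised by `test_reduce_labels`, the following standalone snippet (not part of the patch) replays the remapping implemented in `reduce_label` above: background (0) becomes 255, and every other label shifts down by one.

```python
import numpy as np

label = np.array([[0, 1, 150], [2, 0, 255]], dtype=np.uint8)

# Map background to 255 first so the subtraction below cannot underflow,
# then shift the remaining labels down by one; 255 stays the ignore index.
label[label == 0] = 255
label = label - 1
label[label == 254] = 255

print(label)
# [[255   0 149]
#  [  1 255 255]]
```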