diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 3cae407a70a..2aa1fc5ba1e 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -33,20 +33,33 @@ tasks (image classification, detection, segmentation, video classification). from torchvision import tv_tensors img = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8) - bboxes = torch.randint(0, H // 2, size=(3, 4)) - bboxes[:, 2:] += bboxes[:, :2] - bboxes = tv_tensors.BoundingBoxes(bboxes, format="XYXY", canvas_size=(H, W)) + boxes = torch.randint(0, H // 2, size=(3, 4)) + boxes[:, 2:] += boxes[:, :2] + boxes = tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=(H, W)) # The same transforms can be used! - img, bboxes = transforms(img, bboxes) + img, boxes = transforms(img, boxes) # And you can pass arbitrary input structures - output_dict = transforms({"image": img, "bboxes": bboxes}) + output_dict = transforms({"image": img, "boxes": boxes}) Transforms are typically passed as the ``transform`` or ``transforms`` argument to the :ref:`Datasets `. -.. TODO: Reader guide, i.e. what to read depending on what you're looking for -.. TODO: add link to getting started guide here. +Start here +---------- + +Whether you're new to Torchvision transforms, or you're already experienced with +them, we encourage you to start with +:ref:`sphx_glr_auto_examples_transforms_plot_transforms_getting_started.py` in +order to learn more about what can be done with the new v2 transforms. + +Then, browse the sections in below this page for general information and +performance tips. The available transforms and functionals are listed in the +:ref:`API reference `. + +More information and tutorials can also be found in our :ref:`example gallery +`, e.g. :ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py` +or :ref:`sphx_glr_auto_examples_transforms_plot_custom_transforms.py`. .. _conventions: @@ -98,25 +111,21 @@ advantages compared to the v1 ones (in ``torchvision.transforms``): - They can transform images **but also** bounding boxes, masks, or videos. This provides support for tasks beyond image classification: detection, segmentation, - video classification, etc. + video classification, etc. See + :ref:`sphx_glr_auto_examples_transforms_plot_transforms_getting_started.py` + and :ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`. - They support more transforms like :class:`~torchvision.transforms.v2.CutMix` - and :class:`~torchvision.transforms.v2.MixUp`. + and :class:`~torchvision.transforms.v2.MixUp`. See + :ref:`sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py`. - They're :ref:`faster `. - They support arbitrary input structures (dicts, lists, tuples, etc.). - Future improvements and features will be added to the v2 transforms only. -.. TODO: Add link to e2e example for first bullet point. - These transforms are **fully backward compatible** with the v1 ones, so if you're already using tranforms from ``torchvision.transforms``, all you need to do to is to update the import to ``torchvision.transforms.v2``. In terms of output, there might be negligible differences due to implementation differences. -To learn more about the v2 transforms, check out -:ref:`sphx_glr_auto_examples_transforms_plot_transforms_getting_started.py`. - -.. TODO: make sure link is still good!! - .. 
note:: The v2 transforms are still BETA, but at this point we do not expect @@ -184,7 +193,7 @@ This is very much like the :mod:`torch.nn` package which defines both classes and functional equivalents in :mod:`torch.nn.functional`. The functionals support PIL images, pure tensors, or :ref:`TVTensors -`, e.g. both ``resize(image_tensor)`` and ``resize(bboxes)`` are +`, e.g. both ``resize(image_tensor)`` and ``resize(boxes)`` are valid. .. note:: @@ -248,6 +257,8 @@ be derived from ``torch.nn.Module``. See also: :ref:`sphx_glr_auto_examples_others_plot_scripted_tensor_transforms.py`. +.. _v2_api_ref: + V2 API reference - Recommended ------------------------------ diff --git a/docs/source/tv_tensors.rst b/docs/source/tv_tensors.rst index e80a1ed88fb..cb8a3c45fa9 100644 --- a/docs/source/tv_tensors.rst +++ b/docs/source/tv_tensors.rst @@ -7,9 +7,13 @@ TVTensors TVTensors are :class:`torch.Tensor` subclasses which the v2 :ref:`transforms ` use under the hood to dispatch their inputs to the appropriate -lower-level kernels. Most users do not need to manipulate TVTensors directly and -can simply rely on dataset wrapping - see e.g. -:ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`. +lower-level kernels. Most users do not need to manipulate TVTensors directly. + +Refer to +:ref:`sphx_glr_auto_examples_transforms_plot_transforms_getting_started.py` for +an introduction to TVTensors, or +:ref:`sphx_glr_auto_examples_transforms_plot_tv_tensors.py` for more advanced +info. .. autosummary:: :toctree: generated/ diff --git a/gallery/README.rst b/gallery/README.rst index 9a0838f493f..8dfea355276 100644 --- a/gallery/README.rst +++ b/gallery/README.rst @@ -1,2 +1,4 @@ +.. _gallery: + Examples and tutorials ====================== diff --git a/gallery/transforms/plot_transforms_e2e.py b/gallery/transforms/plot_transforms_e2e.py index 66d9203d70c..6c58b4a5a9a 100644 --- a/gallery/transforms/plot_transforms_e2e.py +++ b/gallery/transforms/plot_transforms_e2e.py @@ -166,3 +166,16 @@ print(f"{[type(target) for target in targets] = }") for name, loss_val in loss_dict.items(): print(f"{name:<20}{loss_val:.3f}") + +# %% +# Training References +# ------------------- +# +# From there, you can check out the `torchvision references +# `_ where you'll find +# the actual training scripts we use to train our models. +# +# **Disclaimer** The code in our references is more complex than what you'll +# need for your own use-cases: this is because we're supporting different +# backends (PIL, tensors, TVTensors) and different transforms namespaces (v1 and +# v2). So don't be afraid to simplify and only keep what you need. diff --git a/gallery/transforms/plot_transforms_getting_started.py b/gallery/transforms/plot_transforms_getting_started.py index cbaab3dc97d..c61d1cc1be0 100644 --- a/gallery/transforms/plot_transforms_getting_started.py +++ b/gallery/transforms/plot_transforms_getting_started.py @@ -217,6 +217,8 @@ # can still be transformed by some transforms like # :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes`!). # +# .. 
_transforms_datasets_intercompatibility: +# # Transforms and Datasets intercompatibility # ------------------------------------------ # diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 834666be36b..175a3ac161c 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -449,68 +449,6 @@ def test__get_params(self, fill, side_range): assert 0 <= params["padding"][3] <= (side_range[1] - 1) * h -class TestRandomCrop: - def test_assertions(self): - with pytest.raises(ValueError, match="Please provide only two dimensions"): - transforms.RandomCrop([10, 12, 14]) - - with pytest.raises(TypeError, match="Got inappropriate padding arg"): - transforms.RandomCrop([10, 12], padding="abc") - - with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"): - transforms.RandomCrop([10, 12], padding=[-0.7, 0, 0.7]) - - with pytest.raises(TypeError, match="Got inappropriate fill arg"): - transforms.RandomCrop([10, 12], padding=1, fill="abc") - - with pytest.raises(ValueError, match="Padding mode should be either"): - transforms.RandomCrop([10, 12], padding=1, padding_mode="abc") - - @pytest.mark.parametrize("padding", [None, 1, [2, 3], [1, 2, 3, 4]]) - @pytest.mark.parametrize("size, pad_if_needed", [((10, 10), False), ((50, 25), True)]) - def test__get_params(self, padding, pad_if_needed, size): - h, w = size = (24, 32) - image = make_image(size) - - transform = transforms.RandomCrop(size, padding=padding, pad_if_needed=pad_if_needed) - params = transform._get_params([image]) - - if padding is not None: - if isinstance(padding, int): - pad_top = pad_bottom = pad_left = pad_right = padding - elif isinstance(padding, list) and len(padding) == 2: - pad_left = pad_right = padding[0] - pad_top = pad_bottom = padding[1] - elif isinstance(padding, list) and len(padding) == 4: - pad_left, pad_top, pad_right, pad_bottom = padding - - h += pad_top + pad_bottom - w += pad_left + pad_right - else: - pad_left = pad_right = pad_top = pad_bottom = 0 - - if pad_if_needed: - if w < size[1]: - diff = size[1] - w - pad_left += diff - pad_right += diff - w += 2 * diff - if h < size[0]: - diff = size[0] - h - pad_top += diff - pad_bottom += diff - h += 2 * diff - - padding = [pad_left, pad_top, pad_right, pad_bottom] - - assert 0 <= params["top"] <= h - size[0] + 1 - assert 0 <= params["left"] <= w - size[1] + 1 - assert params["height"] == size[0] - assert params["width"] == size[1] - assert params["needs_pad"] is any(padding) - assert params["padding"] == padding - - class TestGaussianBlur: def test_assertions(self): with pytest.raises(ValueError, match="Kernel size should be a tuple/list of two integers"): diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py index ffd435aef32..1f96caa247f 100644 --- a/test/test_transforms_v2_consistency.py +++ b/test/test_transforms_v2_consistency.py @@ -304,26 +304,6 @@ def __init__( ], closeness_kwargs={"rtol": 1e-5, "atol": 1e-5}, ), - ConsistencyConfig( - v2_transforms.RandomCrop, - legacy_transforms.RandomCrop, - [ - ArgsKwargs(12), - ArgsKwargs((15, 17)), - NotScriptableArgsKwargs(11, padding=1), - ArgsKwargs(11, padding=[1]), - ArgsKwargs((8, 13), padding=(2, 3)), - ArgsKwargs((14, 9), padding=(0, 2, 1, 0)), - ArgsKwargs(36, pad_if_needed=True), - ArgsKwargs((7, 8), fill=1), - NotScriptableArgsKwargs(5, fill=(1, 2, 3)), - ArgsKwargs(12), - NotScriptableArgsKwargs(15, padding=2, padding_mode="edge"), - ArgsKwargs(17, padding=(1, 0), padding_mode="reflect"), - ArgsKwargs(8, padding=(3, 0, 
0, 1), padding_mode="symmetric"), - ], - make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(26, 26), (18, 33), (29, 22)]), - ), ConsistencyConfig( v2_transforms.RandomPerspective, legacy_transforms.RandomPerspective, @@ -558,7 +538,6 @@ def test_call_consistency(config, args_kwargs): (v2_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])), (v2_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)), (v2_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)), - (v2_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))), (v2_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)), (v2_transforms.AutoAugment, ArgsKwargs(5)), ] diff --git a/test/test_transforms_v2_functional.py b/test/test_transforms_v2_functional.py index cdd75ca0fbf..23f06475cf1 100644 --- a/test/test_transforms_v2_functional.py +++ b/test/test_transforms_v2_functional.py @@ -576,63 +576,6 @@ def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_): return true_matrix -@pytest.mark.parametrize("device", cpu_and_cuda()) -@pytest.mark.parametrize( - "format", - [tv_tensors.BoundingBoxFormat.XYXY, tv_tensors.BoundingBoxFormat.XYWH, tv_tensors.BoundingBoxFormat.CXCYWH], -) -@pytest.mark.parametrize( - "top, left, height, width, expected_bboxes", - [ - [8, 12, 30, 40, [(-2.0, 7.0, 13.0, 27.0), (38.0, -3.0, 58.0, 14.0), (33.0, 38.0, 44.0, 54.0)]], - [-8, 12, 70, 40, [(-2.0, 23.0, 13.0, 43.0), (38.0, 13.0, 58.0, 30.0), (33.0, 54.0, 44.0, 70.0)]], - ], -) -def test_correctness_crop_bounding_boxes(device, format, top, left, height, width, expected_bboxes): - - # Expected bboxes computed using Albumentations: - # import numpy as np - # from albumentations.augmentations.crops.functional import crop_bbox_by_coords, normalize_bbox, denormalize_bbox - # expected_bboxes = [] - # for in_box in in_boxes: - # n_in_box = normalize_bbox(in_box, *size) - # n_out_box = crop_bbox_by_coords( - # n_in_box, (left, top, left + width, top + height), height, width, *size - # ) - # out_box = denormalize_bbox(n_out_box, height, width) - # expected_bboxes.append(out_box) - - format = tv_tensors.BoundingBoxFormat.XYXY - canvas_size = (64, 76) - in_boxes = [ - [10.0, 15.0, 25.0, 35.0], - [50.0, 5.0, 70.0, 22.0], - [45.0, 46.0, 56.0, 62.0], - ] - in_boxes = torch.tensor(in_boxes, device=device) - if format != tv_tensors.BoundingBoxFormat.XYXY: - in_boxes = convert_bounding_box_format(in_boxes, tv_tensors.BoundingBoxFormat.XYXY, format) - - expected_bboxes = clamp_bounding_boxes( - tv_tensors.BoundingBoxes(expected_bboxes, format="XYXY", canvas_size=canvas_size) - ).tolist() - - output_boxes, output_canvas_size = F.crop_bounding_boxes( - in_boxes, - format, - top, - left, - canvas_size[0], - canvas_size[1], - ) - - if format != tv_tensors.BoundingBoxFormat.XYXY: - output_boxes = convert_bounding_box_format(output_boxes, format, tv_tensors.BoundingBoxFormat.XYXY) - - torch.testing.assert_close(output_boxes.tolist(), expected_bboxes) - torch.testing.assert_close(output_canvas_size, canvas_size) - - @pytest.mark.parametrize("device", cpu_and_cuda()) def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device): mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 55d5152041b..ad5cd8e00d8 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -228,7 +228,7 @@ def 
check_functional_kernel_signature_match(functional, *, kernel, input_type): assert functional_param == kernel_param -def _check_transform_v1_compatibility(transform, input, rtol, atol): +def _check_transform_v1_compatibility(transform, input, *, rtol, atol): """If the transform defines the ``_v1_transform_cls`` attribute, checks if the transform has a public, static ``get_params`` method that is the v1 equivalent, the output is close to v1, is scriptable, and the scripted version can be called without error.""" @@ -357,10 +357,11 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, affine_matrix, new def affine_bounding_boxes(bounding_boxes): dtype = bounding_boxes.dtype + device = bounding_boxes.device # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1 input_xyxy = F.convert_bounding_box_format( - bounding_boxes.to(torch.float64, copy=True), + bounding_boxes.to(dtype=torch.float64, device="cpu", copy=True), old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY, inplace=True, @@ -396,9 +397,13 @@ def affine_bounding_boxes(bounding_boxes): output, format=format, canvas_size=canvas_size, - ).to(dtype) + ) + else: + # We leave the bounding box as float64 so the caller gets the full precision to perform any additional + # operation + dtype = output.dtype - return output + return output.to(dtype=dtype, device=device) return tv_tensors.BoundingBoxes( torch.cat([affine_bounding_boxes(b) for b in bounding_boxes.reshape(-1, 4).unbind()], dim=0).reshape( @@ -459,7 +464,7 @@ def _compute_output_size(self, *, input_size, size, max_size): @pytest.mark.parametrize("antialias", [True, False]) @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8]) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_kernel_image_tensor(self, size, interpolation, use_max_size, antialias, dtype, device): + def test_kernel_image(self, size, interpolation, use_max_size, antialias, dtype, device): if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)): return @@ -821,7 +826,7 @@ def test_float16_no_rounding(self): # Non-regression test for https://github.com/pytorch/vision/issues/7667 input = make_image_tensor(self.INPUT_SIZE, dtype=torch.float16) - output = F.resize_image(input, size=self.OUTPUT_SIZES[0]) + output = F.resize_image(input, size=self.OUTPUT_SIZES[0], antialias=True) assert output.dtype is torch.float16 assert (output.round() - output).abs().sum() > 0 @@ -830,7 +835,7 @@ def test_float16_no_rounding(self): class TestHorizontalFlip: @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8]) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_kernel_image_tensor(self, dtype, device): + def test_kernel_image(self, dtype, device): check_kernel(F.horizontal_flip_image, make_image(dtype=dtype, device=device)) @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @@ -980,7 +985,7 @@ def _check_kernel(self, kernel, input, *args, **kwargs): ) @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8]) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_kernel_image_tensor(self, param, value, dtype, device): + def test_kernel_image(self, param, value, dtype, device): if param == "fill": value = adapt_fill(value, dtype=dtype) self._check_kernel( @@ -1280,7 +1285,7 @@ def test_transform_unknown_fill_error(self): class TestVerticalFlip: @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8]) @pytest.mark.parametrize("device", cpu_and_cuda()) - 
def test_kernel_image_tensor(self, dtype, device): + def test_kernel_image(self, dtype, device): check_kernel(F.vertical_flip_image, make_image(dtype=dtype, device=device)) @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @@ -1375,6 +1380,7 @@ def test_transform_noop(self, make_input, device): assert_equal(output, input) +@pytest.mark.filterwarnings("ignore:The provided center argument has no effect") class TestRotate: _EXHAUSTIVE_TYPE_AFFINE_KWARGS = dict( # float, int @@ -1403,7 +1409,7 @@ class TestRotate: ) @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8]) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_kernel_image_tensor(self, param, value, dtype, device): + def test_kernel_image(self, param, value, dtype, device): kwargs = {param: value} if param != "angle": kwargs["angle"] = self._MINIMAL_AFFINE_KWARGS["angle"] @@ -2381,7 +2387,7 @@ def _make_displacement(self, inpt): ) @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8]) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_kernel_image_tensor(self, param, value, dtype, device): + def test_kernel_image(self, param, value, dtype, device): image = make_image_tensor(dtype=dtype, device=device) check_kernel( @@ -2487,6 +2493,233 @@ def test_correctness(self): assert isinstance(out_value, type(input_value)) +class TestCrop: + INPUT_SIZE = (21, 11) + + CORRECTNESS_CROP_KWARGS = [ + # center + dict(top=5, left=5, height=10, width=5), + # larger than input, i.e. pad + dict(top=-5, left=-5, height=30, width=20), + # sides: left, right, top, bottom + dict(top=-5, left=-5, height=30, width=10), + dict(top=-5, left=5, height=30, width=10), + dict(top=-5, left=-5, height=20, width=20), + dict(top=5, left=-5, height=20, width=20), + # corners: top-left, top-right, bottom-left, bottom-right + dict(top=-5, left=-5, height=20, width=10), + dict(top=-5, left=5, height=20, width=10), + dict(top=5, left=-5, height=20, width=10), + dict(top=5, left=5, height=20, width=10), + ] + MINIMAL_CROP_KWARGS = CORRECTNESS_CROP_KWARGS[0] + + @pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS) + @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_kernel_image(self, kwargs, dtype, device): + check_kernel(F.crop_image, make_image(self.INPUT_SIZE, dtype=dtype, device=device), **kwargs) + + @pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) + @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_kernel_bounding_box(self, kwargs, format, dtype, device): + bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format, dtype=dtype, device=device) + check_kernel(F.crop_bounding_boxes, bounding_boxes, format=format, **kwargs) + + @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_mask]) + def test_kernel_mask(self, make_mask): + check_kernel(F.crop_mask, make_mask(self.INPUT_SIZE), **self.MINIMAL_CROP_KWARGS) + + def test_kernel_video(self): + check_kernel(F.crop_video, make_video(self.INPUT_SIZE), **self.MINIMAL_CROP_KWARGS) + + @pytest.mark.parametrize( + "make_input", + [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], + ) + def test_functional(self, make_input): + check_functional(F.crop, make_input(self.INPUT_SIZE), **self.MINIMAL_CROP_KWARGS) + + @pytest.mark.parametrize( + 
("kernel", "input_type"), + [ + (F.crop_image, torch.Tensor), + (F._crop_image_pil, PIL.Image.Image), + (F.crop_image, tv_tensors.Image), + (F.crop_bounding_boxes, tv_tensors.BoundingBoxes), + (F.crop_mask, tv_tensors.Mask), + (F.crop_video, tv_tensors.Video), + ], + ) + def test_functional_signature(self, kernel, input_type): + check_functional_kernel_signature_match(F.crop, kernel=kernel, input_type=input_type) + + @pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS) + def test_functional_image_correctness(self, kwargs): + image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu") + + actual = F.crop(image, **kwargs) + expected = F.to_image(F.crop(F.to_pil_image(image), **kwargs)) + + assert_equal(actual, expected) + + @param_value_parametrization( + size=[(10, 5), (25, 15), (25, 5), (10, 15)], + fill=EXHAUSTIVE_TYPE_FILLS, + ) + @pytest.mark.parametrize( + "make_input", + [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], + ) + def test_transform(self, param, value, make_input): + input = make_input(self.INPUT_SIZE) + + kwargs = {param: value} + if param == "fill": + # 1. size is required + # 2. the fill parameter only has an affect if we need padding + kwargs["size"] = [s + 4 for s in self.INPUT_SIZE] + + if isinstance(input, PIL.Image.Image) and isinstance(value, (tuple, list)) and len(value) == 1: + pytest.xfail("F._pad_image_pil does not support sequences of length 1 for fill.") + + if isinstance(input, tv_tensors.Mask) and isinstance(value, (tuple, list)): + pytest.skip("F.pad_mask doesn't support non-scalar fill.") + + check_transform( + transforms.RandomCrop(**kwargs, pad_if_needed=True), + input, + check_v1_compatibility=param != "fill" or isinstance(value, (int, float)), + ) + + @pytest.mark.parametrize("padding", [1, (1, 1), (1, 1, 1, 1)]) + def test_transform_padding(self, padding): + inpt = make_image(self.INPUT_SIZE) + + output_size = [s + 2 for s in F.get_size(inpt)] + transform = transforms.RandomCrop(output_size, padding=padding) + + output = transform(inpt) + + assert F.get_size(output) == output_size + + @pytest.mark.parametrize("padding", [None, 1, (1, 1), (1, 1, 1, 1)]) + def test_transform_insufficient_padding(self, padding): + inpt = make_image(self.INPUT_SIZE) + + output_size = [s + 3 for s in F.get_size(inpt)] + transform = transforms.RandomCrop(output_size, padding=padding) + + with pytest.raises(ValueError, match="larger than (padded )?input image size"): + transform(inpt) + + def test_transform_pad_if_needed(self): + inpt = make_image(self.INPUT_SIZE) + + output_size = [s * 2 for s in F.get_size(inpt)] + transform = transforms.RandomCrop(output_size, pad_if_needed=True) + + output = transform(inpt) + + assert F.get_size(output) == output_size + + @param_value_parametrization( + size=[(10, 5), (25, 15), (25, 5), (10, 15)], + fill=CORRECTNESS_FILLS, + padding_mode=["constant", "edge", "reflect", "symmetric"], + ) + @pytest.mark.parametrize("seed", list(range(5))) + def test_transform_image_correctness(self, param, value, seed): + kwargs = {param: value} + if param != "size": + # 1. size is required + # 2. 
the fill / padding_mode parameters only have an affect if we need padding + kwargs["size"] = [s + 4 for s in self.INPUT_SIZE] + if param == "fill": + kwargs["fill"] = adapt_fill(kwargs["fill"], dtype=torch.uint8) + + transform = transforms.RandomCrop(pad_if_needed=True, **kwargs) + + image = make_image(self.INPUT_SIZE) + + with freeze_rng_state(): + torch.manual_seed(seed) + actual = transform(image) + + torch.manual_seed(seed) + expected = F.to_image(transform(F.to_pil_image(image))) + + assert_equal(actual, expected) + + def _reference_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, width): + affine_matrix = np.array( + [ + [1, 0, -left], + [0, 1, -top], + ], + ) + return reference_affine_bounding_boxes_helper( + bounding_boxes, affine_matrix=affine_matrix, new_canvas_size=(height, width) + ) + + @pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) + @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_functional_bounding_box_correctness(self, kwargs, format, dtype, device): + bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format, dtype=dtype, device=device) + + actual = F.crop(bounding_boxes, **kwargs) + expected = self._reference_crop_bounding_boxes(bounding_boxes, **kwargs) + + assert_equal(actual, expected, atol=1, rtol=0) + assert_equal(F.get_size(actual), F.get_size(expected)) + + @pytest.mark.parametrize("output_size", [(17, 11), (11, 17), (11, 11)]) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) + @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("seed", list(range(5))) + def test_transform_bounding_boxes_correctness(self, output_size, format, dtype, device, seed): + input_size = [s * 2 for s in output_size] + bounding_boxes = make_bounding_boxes(input_size, format=format, dtype=dtype, device=device) + + transform = transforms.RandomCrop(output_size) + + with freeze_rng_state(): + torch.manual_seed(seed) + params = transform._get_params([bounding_boxes]) + assert not params.pop("needs_pad") + del params["padding"] + assert params.pop("needs_crop") + + torch.manual_seed(seed) + actual = transform(bounding_boxes) + + expected = self._reference_crop_bounding_boxes(bounding_boxes, **params) + + assert_equal(actual, expected) + assert_equal(F.get_size(actual), F.get_size(expected)) + + def test_errors(self): + with pytest.raises(ValueError, match="Please provide only two dimensions"): + transforms.RandomCrop([10, 12, 14]) + + with pytest.raises(TypeError, match="Got inappropriate padding arg"): + transforms.RandomCrop([10, 12], padding="abc") + + with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"): + transforms.RandomCrop([10, 12], padding=[-0.7, 0, 0.7]) + + with pytest.raises(TypeError, match="Got inappropriate fill arg"): + transforms.RandomCrop([10, 12], padding=1, fill="abc") + + with pytest.raises(ValueError, match="Padding mode should be either"): + transforms.RandomCrop([10, 12], padding=1, padding_mode="abc") + + class TestErase: INPUT_SIZE = (17, 11) FUNCTIONAL_KWARGS = dict( diff --git a/test/transforms_v2_dispatcher_infos.py b/test/transforms_v2_dispatcher_infos.py index 446892a812d..6d7ee64d21a 100644 --- a/test/transforms_v2_dispatcher_infos.py +++ b/test/transforms_v2_dispatcher_infos.py @@ -139,16 +139,6 @@ def 
fill_sequence_needs_broadcast(args_kwargs): DISPATCHER_INFOS = [ - DispatcherInfo( - F.crop, - kernels={ - tv_tensors.Image: F.crop_image, - tv_tensors.Video: F.crop_video, - tv_tensors.BoundingBoxes: F.crop_bounding_boxes, - tv_tensors.Mask: F.crop_mask, - }, - pil_kernel_info=PILKernelInfo(F._crop_image_pil, kernel_name="crop_image_pil"), - ), DispatcherInfo( F.resized_crop, kernels={ diff --git a/test/transforms_v2_kernel_infos.py b/test/transforms_v2_kernel_infos.py index afd4cc2e7f2..a549bfe72dd 100644 --- a/test/transforms_v2_kernel_infos.py +++ b/test/transforms_v2_kernel_infos.py @@ -259,105 +259,6 @@ def reference_inputs_convert_bounding_box_format(): ) -_CROP_PARAMS = combinations_grid(top=[-8, 0, 9], left=[-8, 0, 9], height=[12, 20], width=[12, 20]) - - -def sample_inputs_crop_image_tensor(): - for image_loader, params in itertools.product( - make_image_loaders(sizes=[(16, 17)], color_spaces=["RGB"], dtypes=[torch.float32]), - [ - dict(top=4, left=3, height=7, width=8), - dict(top=-1, left=3, height=7, width=8), - dict(top=4, left=-1, height=7, width=8), - dict(top=4, left=3, height=17, width=8), - dict(top=4, left=3, height=7, width=18), - ], - ): - yield ArgsKwargs(image_loader, **params) - - -def reference_inputs_crop_image_tensor(): - for image_loader, params in itertools.product( - make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]), _CROP_PARAMS - ): - yield ArgsKwargs(image_loader, **params) - - -def sample_inputs_crop_bounding_boxes(): - for bounding_boxes_loader, params in itertools.product( - make_bounding_box_loaders(), [_CROP_PARAMS[0], _CROP_PARAMS[-1]] - ): - yield ArgsKwargs(bounding_boxes_loader, format=bounding_boxes_loader.format, **params) - - -def sample_inputs_crop_mask(): - for mask_loader in make_mask_loaders(sizes=[(16, 17)], num_categories=[10], num_objects=[5]): - yield ArgsKwargs(mask_loader, top=4, left=3, height=7, width=8) - - -def reference_inputs_crop_mask(): - for mask_loader, params in itertools.product(make_mask_loaders(extra_dims=[()], num_objects=[1]), _CROP_PARAMS): - yield ArgsKwargs(mask_loader, **params) - - -def sample_inputs_crop_video(): - for video_loader in make_video_loaders(sizes=[(16, 17)], num_frames=[3]): - yield ArgsKwargs(video_loader, top=4, left=3, height=7, width=8) - - -def reference_crop_bounding_boxes(bounding_boxes, *, format, top, left, height, width): - affine_matrix = np.array( - [ - [1, 0, -left], - [0, 1, -top], - ], - dtype="float64" if bounding_boxes.dtype == torch.float64 else "float32", - ) - - canvas_size = (height, width) - expected_bboxes = reference_affine_bounding_boxes_helper( - bounding_boxes, format=format, canvas_size=canvas_size, affine_matrix=affine_matrix - ) - return expected_bboxes, canvas_size - - -def reference_inputs_crop_bounding_boxes(): - for bounding_boxes_loader, params in itertools.product( - make_bounding_box_loaders(extra_dims=((), (4,))), [_CROP_PARAMS[0], _CROP_PARAMS[-1]] - ): - yield ArgsKwargs(bounding_boxes_loader, format=bounding_boxes_loader.format, **params) - - -KERNEL_INFOS.extend( - [ - KernelInfo( - F.crop_image, - kernel_name="crop_image_tensor", - sample_inputs_fn=sample_inputs_crop_image_tensor, - reference_fn=pil_reference_wrapper(F._crop_image_pil), - reference_inputs_fn=reference_inputs_crop_image_tensor, - float32_vs_uint8=True, - ), - KernelInfo( - F.crop_bounding_boxes, - sample_inputs_fn=sample_inputs_crop_bounding_boxes, - reference_fn=reference_crop_bounding_boxes, - reference_inputs_fn=reference_inputs_crop_bounding_boxes, - ), - KernelInfo( - 
F.crop_mask, - sample_inputs_fn=sample_inputs_crop_mask, - reference_fn=pil_reference_wrapper(F._crop_image_pil), - reference_inputs_fn=reference_inputs_crop_mask, - float32_vs_uint8=True, - ), - KernelInfo( - F.crop_video, - sample_inputs_fn=sample_inputs_crop_video, - ), - ] -) - _RESIZED_CROP_PARAMS = combinations_grid(top=[-8, 9], left=[-8, 9], height=[12], width=[12], size=[(16, 18)]) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 7838d7e3eae..8c74f600285 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -1165,7 +1165,7 @@ def pad_image( fill: Optional[Union[int, float, List[float]]] = None, padding_mode: str = "constant", ) -> torch.Tensor: - # Be aware that while `padding` has order `[left, top, right, bottom]` has order, `torch_padding` uses + # Be aware that while `padding` has order `[left, top, right, bottom]`, `torch_padding` uses # `[left, right, top, bottom]`. This stems from the fact that we align our API with PIL, but need to use `torch_pad` # internally. torch_padding = _parse_pad_padding(padding)
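A minimal usage sketch, not part of the patch above, to illustrate what the updated ``transforms.rst`` prose describes and what the new ``TestCrop`` tests exercise: a v2 ``RandomCrop`` applied jointly to an image and its ``BoundingBoxes``, passed as an arbitrary input structure (a dict). The image size, crop size, and box coordinates below are arbitrary choices for the example::

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms import v2

    H, W = 64, 76
    img = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)
    boxes = tv_tensors.BoundingBoxes(
        torch.tensor([[10.0, 15.0, 25.0, 35.0]]),  # XYXY coordinates
        format="XYXY",
        canvas_size=(H, W),
    )

    # pad_if_needed mirrors the option exercised by TestCrop.test_transform above.
    crop = v2.RandomCrop(size=(30, 40), pad_if_needed=True)

    # Arbitrary input structures are supported; the image and the boxes
    # receive the same randomly sampled crop parameters.
    out = crop({"image": img, "boxes": boxes})
    print(out["image"].shape)        # torch.Size([3, 30, 40])
    print(out["boxes"].canvas_size)  # (30, 40)

The same call also accepts plain tensors, PIL images, ``Mask`` and ``Video`` TVTensors, which is what the parametrizations over ``make_image_tensor``, ``make_image_pil``, ``make_segmentation_mask``, and ``make_video`` in ``TestCrop`` cover.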