From 46a47bce861e2b1ff5321da2cb30111e0d0673bf Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 24 May 2023 15:32:57 +0200 Subject: [PATCH 1/8] add PermuteChannels transform --- test/test_transforms_v2.py | 1 + torchvision/datapoints/_datapoint.py | 3 ++ torchvision/datapoints/_image.py | 4 ++ torchvision/datapoints/_video.py | 4 ++ torchvision/transforms/v2/__init__.py | 1 + torchvision/transforms/v2/_color.py | 34 +++++++------- .../transforms/v2/functional/__init__.py | 4 ++ .../transforms/v2/functional/_color.py | 45 ++++++++++++++++++- 8 files changed, 78 insertions(+), 18 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 02e3e1e569a..e4d6b8754ba 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -126,6 +126,7 @@ class TestSmoke: (transforms.RandomEqualize(p=1.0), None), (transforms.RandomGrayscale(p=1.0), None), (transforms.RandomInvert(p=1.0), None), + (transforms.PermuteChannels(), None), (transforms.RandomPhotometricDistort(p=1.0), None), (transforms.RandomPosterize(bits=4, p=1.0), None), (transforms.RandomSolarize(threshold=0.5, p=1.0), None), diff --git a/torchvision/datapoints/_datapoint.py b/torchvision/datapoints/_datapoint.py index fe489d13ea0..c2fb90b202a 100644 --- a/torchvision/datapoints/_datapoint.py +++ b/torchvision/datapoints/_datapoint.py @@ -254,6 +254,9 @@ def invert(self) -> Datapoint: def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Datapoint: return self + def permute_channels(self, permutation: List[int]) -> Datapoint: + return self + _InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint] _InputTypeJIT = torch.Tensor diff --git a/torchvision/datapoints/_image.py b/torchvision/datapoints/_image.py index e47a6c10fc3..59aa35bcd6f 100644 --- a/torchvision/datapoints/_image.py +++ b/torchvision/datapoints/_image.py @@ -253,6 +253,10 @@ def normalize(self, mean: List[float], std: List[float], inplace: bool = False) output = 
self._F.normalize_image_tensor(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace) return Image.wrap_like(self, output) + def permute_channels(self, permutation: List[int]) -> Image: + output = self._F.permute_channels_image_tensor(self.as_subclass(torch.Tensor), permutation=permutation) + return Image.wrap_like(self, output) + _ImageType = Union[torch.Tensor, PIL.Image.Image, Image] _ImageTypeJIT = torch.Tensor diff --git a/torchvision/datapoints/_video.py b/torchvision/datapoints/_video.py index a6fbe2bd473..62afff52aac 100644 --- a/torchvision/datapoints/_video.py +++ b/torchvision/datapoints/_video.py @@ -243,6 +243,10 @@ def normalize(self, mean: List[float], std: List[float], inplace: bool = False) output = self._F.normalize_video(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace) return Video.wrap_like(self, output) + def permute_channels(self, permutation: List[int]) -> Video: + output = self._F.permute_channels_image_tensor(self.as_subclass(torch.Tensor), permutation=permutation) + return Video.wrap_like(self, output) + _VideoType = Union[torch.Tensor, Video] _VideoTypeJIT = torch.Tensor diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py index 6573446a33a..47cf0e47df7 100644 --- a/torchvision/transforms/v2/__init__.py +++ b/torchvision/transforms/v2/__init__.py @@ -9,6 +9,7 @@ from ._color import ( ColorJitter, Grayscale, + PermuteChannels, RandomAdjustSharpness, RandomAutocontrast, RandomEqualize, diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py index 7dd8eeae236..1d08d251a48 100644 --- a/torchvision/transforms/v2/_color.py +++ b/torchvision/transforms/v2/_color.py @@ -177,7 +177,22 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return output -# TODO: This class seems to be untested +class PermuteChannels(Transform): + _transformed_types = ( + datapoints.Image, + PIL.Image.Image, + is_simple_tensor, + datapoints.Video, + ) + 
+ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + num_channels, *_ = query_chw(flat_inputs) + return dict(permutation=torch.randperm(num_channels).tolist()) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.permute_channels(inpt, params["permutation"]) + + class RandomPhotometricDistort(Transform): """[BETA] Randomly distorts the image or video as used in `SSD: Single Shot MultiBox Detector `_. @@ -241,21 +256,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None return params - def _permute_channels( - self, inpt: Union[datapoints._ImageType, datapoints._VideoType], permutation: torch.Tensor - ) -> Union[datapoints._ImageType, datapoints._VideoType]: - orig_inpt = inpt - if isinstance(orig_inpt, PIL.Image.Image): - inpt = F.pil_to_tensor(inpt) - - # TODO: Find a better fix than as_subclass??? - output = inpt[..., permutation, :, :].as_subclass(type(inpt)) - - if isinstance(orig_inpt, PIL.Image.Image): - output = F.to_image_pil(output) - - return output - def _transform( self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any] ) -> Union[datapoints._ImageType, datapoints._VideoType]: @@ -270,7 +270,7 @@ def _transform( if params["contrast_factor"] is not None and not params["contrast_before"]: inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"]) if params["channel_permutation"] is not None: - inpt = self._permute_channels(inpt, permutation=params["channel_permutation"]) + inpt = F.permute_channels(inpt, permutation=params["channel_permutation"]) return inpt diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index ffb34c87748..91516972f15 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -65,6 +65,10 @@ invert_image_pil, 
invert_image_tensor, invert_video, + permute_channels, + permute_channels_image_pil, + permute_channels_image_tensor, + permute_channels_video, posterize, posterize_image_pil, posterize_image_tensor, diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 4ba7e5b36b3..e1b08efa3e9 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import List, Union import PIL.Image import torch @@ -10,6 +10,7 @@ from torchvision.utils import _log_api_usage_once from ._meta import _num_value_bits, convert_dtype_image_tensor +from ._type_conversion import pil_to_tensor, to_image_pil from ._utils import is_simple_tensor @@ -670,3 +671,45 @@ def invert(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"but got {type(inpt)} instead." ) + + +def permute_channels_image_tensor(image: torch.Tensor, permutation: List[int]) -> torch.Tensor: + shape = image.shape + num_channels, height, width = shape[-3:] + + if len(permutation) != num_channels: + raise ValueError( + f"Length of permutation does not match number of channels: " f"{len(permutation)} != {num_channels}" + ) + + if image.numel() == 0: + return image + + image = image.reshape(-1, num_channels, height, width) + image = image[:, permutation, :, :] + return image.reshape(shape) + + +def permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int]) -> PIL.Image: + return to_image_pil(permute_channels_image_tensor(pil_to_tensor(image), permutation=permutation)) + + +def permute_channels_video(video: torch.Tensor, permutation: List[int]) -> torch.Tensor: + return permute_channels_image_tensor(video, permutation=permutation) + + +def permute_channels(inpt: datapoints._InputTypeJIT, permutation: List[int]) -> datapoints._InputTypeJIT: + if not 
torch.jit.is_scripting(): + _log_api_usage_once(permute_channels) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return permute_channels_image_tensor(inpt, permutation=permutation) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.permute_channels(permutation=permutation) + elif isinstance(inpt, PIL.Image.Image): + return permute_channels_image_pil(inpt, permutation=permutation) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) From e4be568959c86623d9e46ae9968e336fbecce6a9 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 7 Aug 2023 13:00:59 +0200 Subject: [PATCH 2/8] cleanup --- torchvision/transforms/v2/functional/_color.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index bd94c8339f2..01fcf6913ac 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -647,7 +647,7 @@ def invert_video(video: torch.Tensor) -> torch.Tensor: def permute_channels(inpt: datapoints._InputTypeJIT, permutation: List[int]) -> datapoints._InputTypeJIT: if torch.jit.is_scripting(): - return invert_image_tensor(inpt, permutation=permutation) + return permute_channels_image_tensor(inpt, permutation=permutation) _log_api_usage_once(permute_channels) From 8b94f6d69b07ee8dac126de5c981502e1ac04545 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 7 Aug 2023 13:15:54 +0200 Subject: [PATCH 3/8] add tests --- test/test_transforms_v2_refactored.py | 57 +++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index c910882f9fd..ff67e8a7e9f 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -1,6 +1,7 @@ import contextlib import decimal import 
inspect +import itertools import math import re from unittest import mock @@ -2280,3 +2281,59 @@ def resize_my_datapoint(): _register_kernel_internal(F.resize, MyDatapoint, datapoint_wrapper=False)(resize_my_datapoint) assert _get_kernel(F.resize, MyDatapoint) is resize_my_datapoint + + +class TestPermuteChannels: + _CORRECTNESS_PERMUTATIONS = list(itertools.permutations(range(3))) + _DEFAULT_PERMUTATION = _CORRECTNESS_PERMUTATIONS[-1] + + @pytest.mark.parametrize( + ("kernel", "make_input"), + [ + (F.permute_channels_image_tensor, make_image_tensor), + # FIXME + # check_kernel does not support PIL kernel, but it should + (F.permute_channels_image_tensor, make_image), + (F.permute_channels_video, make_video), + ], + ) + @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_kernel(self, kernel, make_input, dtype, device): + check_kernel(kernel, make_input(dtype=dtype, device=device), permutation=self._DEFAULT_PERMUTATION) + + @pytest.mark.parametrize( + ("kernel", "make_input"), + [ + (F.permute_channels_image_tensor, make_image_tensor), + (F.permute_channels_image_pil, make_image_pil), + (F.permute_channels_image_tensor, make_image), + (F.permute_channels_video, make_video), + ], + ) + def test_dispatcher(self, kernel, make_input): + check_dispatcher(F.permute_channels, kernel, make_input(), permutation=self._DEFAULT_PERMUTATION) + + @pytest.mark.parametrize( + ("kernel", "input_type"), + [ + (F.adjust_brightness_image_tensor, torch.Tensor), + (F.adjust_brightness_image_pil, PIL.Image.Image), + (F.adjust_brightness_image_tensor, datapoints.Image), + (F.adjust_brightness_video, datapoints.Video), + ], + ) + def test_dispatcher_signature(self, kernel, input_type): + check_dispatcher_kernel_signature_match(F.adjust_brightness, kernel=kernel, input_type=input_type) + + def reference_image_correctness(self, image, permutation): + return datapoints.Image(image.numpy()[permutation, ...]) + + 
@pytest.mark.parametrize("permutation", _CORRECTNESS_PERMUTATIONS) + def test_image_correctness(self, permutation): + image = make_image() + + actual = F.permute_channels(image, permutation=permutation) + expected = self.reference_image_correctness(image, permutation=permutation) + + torch.testing.assert_close(actual, expected) From fdbb4adc96eba14f0cd2c731f7605b67bfb884d4 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 7 Aug 2023 14:11:12 +0200 Subject: [PATCH 4/8] fix tests --- test/test_transforms_v2_refactored.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index ff67e8a7e9f..10b7dd91581 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -2317,14 +2317,14 @@ def test_dispatcher(self, kernel, make_input): @pytest.mark.parametrize( ("kernel", "input_type"), [ - (F.adjust_brightness_image_tensor, torch.Tensor), - (F.adjust_brightness_image_pil, PIL.Image.Image), - (F.adjust_brightness_image_tensor, datapoints.Image), - (F.adjust_brightness_video, datapoints.Video), + (F.permute_channels_image_tensor, torch.Tensor), + (F.permute_channels_image_pil, PIL.Image.Image), + (F.permute_channels_image_tensor, datapoints.Image), + (F.permute_channels_video, datapoints.Video), ], ) def test_dispatcher_signature(self, kernel, input_type): - check_dispatcher_kernel_signature_match(F.adjust_brightness, kernel=kernel, input_type=input_type) + check_dispatcher_kernel_signature_match(F.permute_channels, kernel=kernel, input_type=input_type) def reference_image_correctness(self, image, permutation): return datapoints.Image(image.numpy()[permutation, ...]) From 95f7d4680ed03546cae96d0f1ed2653ef88ea214 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 7 Aug 2023 14:13:37 +0200 Subject: [PATCH 5/8] fix tests --- torchvision/transforms/v2/functional/_color.py | 1 + 1 file changed, 1 insertion(+) diff --git 
a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 01fcf6913ac..ad530600708 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -645,6 +645,7 @@ def invert_video(video: torch.Tensor) -> torch.Tensor: return invert_image_tensor(video) +@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def permute_channels(inpt: datapoints._InputTypeJIT, permutation: List[int]) -> datapoints._InputTypeJIT: if torch.jit.is_scripting(): return permute_channels_image_tensor(inpt, permutation=permutation) From a5708e4a88028195992b6e473d209a75d446fa65 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 9 Aug 2023 09:29:52 +0200 Subject: [PATCH 6/8] address comments --- docs/source/transforms.rst | 1 + test/test_transforms_v2.py | 2 +- test/test_transforms_v2_refactored.py | 17 +++++++++-------- torchvision/transforms/v2/__init__.py | 2 +- torchvision/transforms/v2/_color.py | 7 ++++++- torchvision/transforms/v2/functional/_color.py | 15 +++++++++++++++ 6 files changed, 33 insertions(+), 11 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index a1858c6b514..0df46c92530 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -155,6 +155,7 @@ Color ColorJitter v2.ColorJitter + v2.RandomChannelPermutation v2.RandomPhotometricDistort Grayscale v2.Grayscale diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index ef2cc2a1270..5f4a9b62898 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -124,7 +124,7 @@ class TestSmoke: (transforms.RandomEqualize(p=1.0), None), (transforms.RandomGrayscale(p=1.0), None), (transforms.RandomInvert(p=1.0), None), - (transforms.RandomPermuteChannels(), None), + (transforms.RandomChannelPermutation(), None), (transforms.RandomPhotometricDistort(p=1.0), None), (transforms.RandomPosterize(bits=4, p=1.0), None), 
(transforms.RandomSolarize(threshold=0.5, p=1.0), None), diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 10b7dd91581..fa04d5deb0c 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -1,7 +1,6 @@ import contextlib import decimal import inspect -import itertools import math import re from unittest import mock @@ -2284,8 +2283,7 @@ def resize_my_datapoint(): class TestPermuteChannels: - _CORRECTNESS_PERMUTATIONS = list(itertools.permutations(range(3))) - _DEFAULT_PERMUTATION = _CORRECTNESS_PERMUTATIONS[-1] + _DEFAULT_PERMUTATION = [2, 0, 1] @pytest.mark.parametrize( ("kernel", "make_input"), @@ -2327,11 +2325,14 @@ def test_dispatcher_signature(self, kernel, input_type): check_dispatcher_kernel_signature_match(F.permute_channels, kernel=kernel, input_type=input_type) def reference_image_correctness(self, image, permutation): - return datapoints.Image(image.numpy()[permutation, ...]) - - @pytest.mark.parametrize("permutation", _CORRECTNESS_PERMUTATIONS) - def test_image_correctness(self, permutation): - image = make_image() + channel_images = image.split(1, dim=-3) + permuted_channel_images = [channel_images[channel_idx] for channel_idx in permutation] + return datapoints.Image(torch.concat(permuted_channel_images, dim=-3)) + + @pytest.mark.parametrize("permutation", [[2, 0, 1], [1, 2, 0], [2, 0, 1], [0, 1, 2]]) + @pytest.mark.parametrize("batch_dims", [(), (2,), (2, 1)]) + def test_image_correctness(self, permutation, batch_dims): + image = make_image(batch_dims=batch_dims) actual = F.permute_channels(image, permutation=permutation) expected = self.reference_image_correctness(image, permutation=permutation) diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py index 432bdc21577..4451cb7a1a2 100644 --- a/torchvision/transforms/v2/__init__.py +++ b/torchvision/transforms/v2/__init__.py @@ -11,10 +11,10 @@ Grayscale, RandomAdjustSharpness, 
RandomAutocontrast, + RandomChannelPermutation, RandomEqualize, RandomGrayscale, RandomInvert, - RandomPermuteChannels, RandomPhotometricDistort, RandomPosterize, RandomSolarize, diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py index 3bc01c7ba28..1543a43c9eb 100644 --- a/torchvision/transforms/v2/_color.py +++ b/torchvision/transforms/v2/_color.py @@ -177,7 +177,12 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return output -class RandomPermuteChannels(Transform): +class RandomChannelPermutation(Transform): + """[BETA] Randomly permute the channels of an image or video + + .. v2betastatus:: RandomChannelPermutation transform + """ + _transformed_types = ( datapoints.Image, PIL.Image.Image, diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index ad530600708..777463a8efb 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -647,6 +647,21 @@ def invert_video(video: torch.Tensor) -> torch.Tensor: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def permute_channels(inpt: datapoints._InputTypeJIT, permutation: List[int]) -> datapoints._InputTypeJIT: + """Permute the channels of the input according to the given permutation. + + This function supports plain :class:`~torch.Tensor`'s, :class:`PIL.Image.Image`'s, and + :class:`torchvision.datapoints.Image` and :class:`torchvision.datapoints.Video`. + + Example: + >>> rgb_image = torch.rand(3, 256, 256) + >>> bgr_image = F.permutate_channels(rgb_image, permutation=[2, 1, 0]) + + Args: + permutation (List[int]): Valid permutation of the input channel indices. + + Raises: + ValueError: If ``len(permutation)`` doesn't match the number of channels in the input. 
+ """ if torch.jit.is_scripting(): return permute_channels_image_tensor(inpt, permutation=permutation) From ec8546ab81acffe5c827ddb1fc4e83b699c5621f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 9 Aug 2023 09:35:14 +0200 Subject: [PATCH 7/8] remove the .tolist() call in the transform --- torchvision/transforms/v2/_color.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py index 1543a43c9eb..8315e2f36b4 100644 --- a/torchvision/transforms/v2/_color.py +++ b/torchvision/transforms/v2/_color.py @@ -192,7 +192,7 @@ class RandomChannelPermutation(Transform): def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: num_channels, *_ = query_chw(flat_inputs) - return dict(permutation=torch.randperm(num_channels).tolist()) + return dict(permutation=torch.randperm(num_channels)) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return F.permute_channels(inpt, params["permutation"]) From 72c3a937b2e9b8b75a7fee6654b7c7db1c09cbdb Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 9 Aug 2023 10:44:13 +0200 Subject: [PATCH 8/8] improve functional doc --- torchvision/transforms/v2/functional/_color.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 777463a8efb..9b6bf3886fa 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -657,7 +657,13 @@ def permute_channels(inpt: datapoints._InputTypeJIT, permutation: List[int]) -> >>> bgr_image = F.permutate_channels(rgb_image, permutation=[2, 1, 0]) Args: - permutation (List[int]): Valid permutation of the input channel indices. + permutation (List[int]): Valid permutation of the input channel indices. The index of the element determines the + channel index in the input and the value determines the channel index in the output. 
For example, + ``permutation=[2, 0, 1]`` + + - takes ``inpt[..., 0, :, :]`` and puts it at ``output[..., 2, :, :]``, + - takes ``inpt[..., 1, :, :]`` and puts it at ``output[..., 0, :, :]``, and + - takes ``inpt[..., 2, :, :]`` and puts it at ``output[..., 1, :, :]``. Raises: ValueError: If ``len(permutation)`` doesn't match the number of channels in the input.