From 9562e2a583e4d9a7a5fca3be1aeea5d4c8780a8d Mon Sep 17 00:00:00 2001 From: Nicolas Granger Date: Sat, 29 Jul 2023 22:47:09 +0200 Subject: [PATCH] Replace stack/mask/reduce by indexing in _hsv_to_rgb Fixes #7753 --- .../transforms/v2/functional/_color.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 4c087965f6c..93970cb79f1 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -317,14 +317,22 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor: p = one_minus_s.mul_(v).clamp_(0.0, 1.0) i.remainder_(6) - mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1) + vpqt = torch.stack((v, p, q, t), dim=-3) - a1 = torch.stack((v, q, p, p, t, v), dim=-3) - a2 = torch.stack((t, v, v, q, p, p), dim=-3) - a3 = torch.stack((p, p, t, v, v, q), dim=-3) - a4 = torch.stack((a1, a2, a3), dim=-4) + # vpqt -> rgb mapping based on i + select = torch.tensor( + [[0, 2, 1, 1, 3, 0], [3, 0, 0, 2, 1, 1], [1, 1, 3, 0, 0, 2]], dtype=torch.long + ) + select = select.to(device=img.device, non_blocking=True) + + select = select[:, i] + if select.ndim > 3: + # if input.shape is (B, ..., C, H, W) then + # select.shape is (C, B, ..., H, W) + # thus we move C axis to get (B, ..., C, H, W) + select = select.moveaxis(0, -3) - return (a4.mul_(mask.unsqueeze(dim=-4))).sum(dim=-3) + return vpqt.gather(-3, select) @_register_kernel_internal(adjust_hue, torch.Tensor)