Skip to content

Commit

Permalink
Replace stack/mask/reduce by indexing in _hsv_to_rgb
Browse files Browse the repository at this point in the history
Fixes #7753
  • Loading branch information
nlgranger committed Jul 29, 2023
1 parent b9b7cfc commit 5902f34
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions torchvision/transforms/v2/functional/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,14 +325,22 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
p = one_minus_s.mul_(v).clamp_(0.0, 1.0)
i.remainder_(6)

mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)

a1 = torch.stack((v, q, p, p, t, v), dim=-3)
a2 = torch.stack((t, v, v, q, p, p), dim=-3)
a3 = torch.stack((p, p, t, v, v, q), dim=-3)
a4 = torch.stack((a1, a2, a3), dim=-4)

return (a4.mul_(mask.unsqueeze(dim=-4))).sum(dim=-3)
vpqt = torch.stack((v, p, q, t), dim=-3)

# vpqt -> rgb mapping based on i
select = torch.tensor([[0, 2, 1, 1, 3, 0], [3, 0, 0, 2, 1, 1], [1, 1, 3, 0, 0, 2]], dtype=torch.long)
if img.device.type == "cuda":
select = select.pin_memory(img.device)
select = select.to(device=img.device, non_blocking=True)

select = select[:, i]
if select.ndim > 3:
# if input.shape is (B, ..., C, H, W) then
# select.shape is (C, B, ..., H, W)
# thus we move C axis to get (B, ..., C, H, W)
select = select.moveaxis(0, -3)

return vpqt.gather(-3, select)


def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
Expand Down

0 comments on commit 5902f34

Please sign in to comment.