Skip to content

Commit

Permalink
Replace stack/mask/reduce by indexing in _hsv_to_rgb
Browse files Browse the repository at this point in the history
Fixes #7753
  • Loading branch information
nlgranger committed Aug 12, 2023
1 parent cab01fc commit 9562e2a
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions torchvision/transforms/v2/functional/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,14 +317,22 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
p = one_minus_s.mul_(v).clamp_(0.0, 1.0)
i.remainder_(6)

mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)
vpqt = torch.stack((v, p, q, t), dim=-3)

a1 = torch.stack((v, q, p, p, t, v), dim=-3)
a2 = torch.stack((t, v, v, q, p, p), dim=-3)
a3 = torch.stack((p, p, t, v, v, q), dim=-3)
a4 = torch.stack((a1, a2, a3), dim=-4)
# vpqt -> rgb mapping based on i
select = torch.tensor(
[[0, 2, 1, 1, 3, 0], [3, 0, 0, 2, 1, 1], [1, 1, 3, 0, 0, 2]], dtype=torch.long
)
select = select.to(device=img.device, non_blocking=True)

select = select[:, i]
if select.ndim > 3:
# if input.shape is (B, ..., C, H, W) then
# select.shape is (C, B, ..., H, W)
# thus we move C axis to get (B, ..., C, H, W)
select = select.moveaxis(0, -3)

return (a4.mul_(mask.unsqueeze(dim=-4))).sum(dim=-3)
return vpqt.gather(-3, select)


@_register_kernel_internal(adjust_hue, torch.Tensor)
Expand Down

0 comments on commit 9562e2a

Please sign in to comment.