diff --git a/piqa/gmsd.py b/piqa/gmsd.py index 5016e7e..8986ef4 100644 --- a/piqa/gmsd.py +++ b/piqa/gmsd.py @@ -1,12 +1,18 @@ r"""Gradient Magnitude Similarity Deviation (GMSD) +and Multi-Scale Gradient Magnitude Similiarity Deviation (MS-GMSD) -This module implements the GMSD in PyTorch. +This module implements the GMSD and MS-GMSD in PyTorch. References: [1] Gradient Magnitude Similarity Deviation: An Highly Efficient Perceptual Image Quality Index (Xue et al., 2013) https://arxiv.org/abs/1308.3052 + + [2] Gradient Magnitude Similarity Deviation on + multiple scales for color image quality assessment + (Zhang et al., 2017) + https://ieeexplore.ieee.org/document/7952357 """ import torch @@ -16,6 +22,7 @@ from piqa.utils import build_reduce, prewitt_kernel, filter2d, tensor_norm _L_WEIGHTS = torch.FloatTensor([0.299, 0.587, 0.114]) +_MS_WEIGHTS = torch.FloatTensor([0.096, 0.596, 0.289, 0.019]) def _gmsd( @@ -24,6 +31,7 @@ def _gmsd( kernel: torch.Tensor, value_range: float = 1., c: float = 0.00261, # 170. / (255. ** 2) + alpha: float = 0., ) -> torch.Tensor: r"""Returns the GMSD between `x` and `y`, without downsampling and color space conversion. @@ -36,7 +44,7 @@ def _gmsd( kernel: A 2D gradient kernel, (2, 1, K, K). value_range: The value range of the inputs (usually 1. or 255). - For the remaining arguments, refer to [1]. + For the remaining arguments, refer to [1] and [2]. Example: >>> x = torch.rand(5, 1, 256, 256) @@ -55,10 +63,14 @@ def _gmsd( gm_x = tensor_norm(filter2d(x, kernel, padding=pad), dim=1) gm_y = tensor_norm(filter2d(y, kernel, padding=pad), dim=1) + gm_xy = gm_x * gm_y + # Gradient magnitude similarity - gms = (2. * gm_x * gm_y + c) / (gm_x ** 2 + gm_y ** 2 + c) + gms_num = (2. - alpha) * gm_xy + c + gms_den = gm_x ** 2 + gm_y ** 2 - alpha * gm_xy + c + gms = gms_num / gms_den - # Gradient magnitude similarity diviation + # Gradient magnitude similarity deviation gmsd = (gms - gms.mean((-1, -2), keepdim=True)) ** 2 gmsd = torch.sqrt(gmsd.mean((-1, -2))) @@ -108,6 +120,98 @@ def gmsd( return _gmsd(x, y, kernel, **kwargs) +def _msgmsd( + x: torch.Tensor, + y: torch.Tensor, + kernel: torch.Tensor, + weights: torch.Tensor, + alpha: float = 0.5, + **kwargs, +) -> torch.Tensor: + r"""Returns the MS-GMSD between `x` and `y`, + without color space conversion. + + `_msgmsd` is an auxiliary function for `msgmsd` and `MSGMSD`. + + Args: + x: An input tensor, (N, 1, H, W). + y: A target tensor, (N, 1, H, W). + kernel: A 2D gradient kernel, (2, 1, K, K). + weights: The weights of the scales, (M,). + + `alpha` and `**kwargs` are transmitted to `_gmsd`. + + Example: + >>> x = torch.rand(5, 1, 256, 256) + >>> y = torch.rand(5, 1, 256, 256) + >>> kernel = torch.rand(2, 1, 3, 3) + >>> weights = torch.rand(4) + >>> l = _msgmsd(x, y, kernel, weights) + >>> l.size() + torch.Size([5]) + """ + + gmsds = [] + + for i in range(weights.numel()): + if i > 0: + x = F.avg_pool2d(x, kernel_size=2, ceil_mode=True) + y = F.avg_pool2d(y, kernel_size=2, ceil_mode=True) + + gmsds.append(_gmsd(x, y, kernel, alpha=alpha, **kwargs)) + + msgmsd = torch.stack(gmsds, dim=-1) ** 2 + msgmsd = torch.sqrt((msgmsd * weights).sum(dim=-1)) + + return msgmsd + + +def msgmsd( + x: torch.Tensor, + y: torch.Tensor, + kernel: torch.Tensor = None, + weights: torch.Tensor = None, + **kwargs, +) -> torch.Tensor: + r"""Returns the MS-GMSD between `x` and `y`. + + Args: + x: An input tensor, (N, 3, H, W). + y: A target tensor, (N, 3, H, W). + kernel: A 2D gradient kernel, (2, 1, K, K). + If `None`, use the Prewitt kernel instead. + weights: The weights of the scales, (M,). + If `None`, use the official weights instead. + + `**kwargs` are transmitted to `_msgmsd`. + + Example: + >>> x = torch.rand(5, 3, 256, 256) + >>> y = torch.rand(5, 3, 256, 256) + >>> l = msgmsd(x, y) + >>> l.size() + torch.Size([5]) + """ + + # RGB to luminance + l_weights = _L_WEIGHTS.to(x.device).view(1, 3, 1, 1) + + x = F.conv2d(x, l_weights) + y = F.conv2d(y, l_weights) + + # Kernel + if kernel is None: + kernel = prewitt_kernel() + kernel = torch.stack([kernel, kernel.t()]).unsqueeze(1) + kernel = kernel.to(x.device) + + # Weights + if weights is None: + weights = _MS_WEIGHTS.to(x.device) + + return _msgmsd(x, y, kernel, weights, **kwargs) + + class GMSD(nn.Module): r"""Creates a criterion that measures the GMSD between an input and a target. @@ -173,3 +277,73 @@ def forward( l = _gmsd(input, target, self.kernel, **self.kwargs) return self.reduce(l) + + +class MSGMSD(nn.Module): + r"""Creates a criterion that measures the MSGMSD + between an input and a target. + + Args: + kernel: A 2D gradient kernel, (2, 1, K, K). + If `None`, use the Prewitt kernel instead. + weights: The weights of the scales, (M,). + If `None`, use the official weights instead. + reduction: Specifies the reduction to apply to the output: + `'none'` | `'mean'` | `'sum'`. + + `**kwargs` are transmitted to `_msgmsd`. + + Shape: + * Input: (N, 3, H, W) + * Target: (N, 3, H, W) + * Output: (N,) or (1,) depending on `reduction` + + Example: + >>> criterion = MSGMSD().cuda() + >>> x = torch.rand(5, 3, 256, 256).cuda() + >>> y = torch.rand(5, 3, 256, 256).cuda() + >>> l = criterion(x, y) + >>> l.size() + torch.Size([]) + """ + + def __init__( + self, + kernel: torch.Tensor = None, + weights: torch.Tensor = None, + reduction: str = 'mean', + **kwargs, + ): + r"""""" + super().__init__() + + if kernel is None: + kernel = prewitt_kernel() + kernel = torch.stack([kernel, kernel.t()]).unsqueeze(1) + + if weights is None: + weights = _MS_WEIGHTS + + self.register_buffer('kernel', kernel) + self.register_buffer('weights', weights) + self.register_buffer('l_weights', _L_WEIGHTS.view(1, 3, 1, 1)) + + self.reduce = build_reduce(reduction) + self.kwargs = kwargs + + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + ) -> torch.Tensor: + r"""Defines the computation performed at every call. + """ + + # RGB to luminance + input = F.conv2d(input, self.l_weights) + target = F.conv2d(target, self.l_weights) + + # MSGMSD + l = _msgmsd(input, target, self.kernel, self.weights, **self.kwargs) + + return self.reduce(l) diff --git a/piqa/ssim.py b/piqa/ssim.py index a59a150..7a3467a 100644 --- a/piqa/ssim.py +++ b/piqa/ssim.py @@ -26,7 +26,7 @@ from typing import Tuple -_WEIGHTS = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]) +_MS_WEIGHTS = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]) def create_window(window_size: int, n_channels: int) -> torch.Tensor: @@ -209,7 +209,7 @@ def msssim( window = create_window(window_size, n_channels).to(x.device) if weights is None: - weights = _WEIGHTS.to(x.device) + weights = _MS_WEIGHTS.to(x.device) return msssim_per_channel(x, y, window, weights, **kwargs).mean(-1) @@ -313,7 +313,7 @@ def __init__( super().__init__() if weights is None: - weights = _WEIGHTS + weights = _MS_WEIGHTS self.register_buffer('window', create_window(window_size, n_channels)) self.register_buffer('weights', weights)