Skip to content

Commit

Permalink
✨ Multi-Scale Gradient Magnitude Similarity Deviation (MS-GMSD)
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Dec 10, 2020
1 parent e6179ae commit 2364fbc
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 7 deletions.
182 changes: 178 additions & 4 deletions piqa/gmsd.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
r"""Gradient Magnitude Similarity Deviation (GMSD)
and Multi-Scale Gradient Magnitude Similiarity Deviation (MS-GMSD)
This module implements the GMSD in PyTorch.
This module implements the GMSD and MS-GMSD in PyTorch.
References:
[1] Gradient Magnitude Similarity Deviation:
An Highly Efficient Perceptual Image Quality Index
(Xue et al., 2013)
https://arxiv.org/abs/1308.3052
[2] Gradient Magnitude Similarity Deviation on
multiple scales for color image quality assessment
(Zhang et al., 2017)
https://ieeexplore.ieee.org/document/7952357
"""

import torch
Expand All @@ -16,6 +22,7 @@
from piqa.utils import build_reduce, prewitt_kernel, filter2d, tensor_norm

_L_WEIGHTS = torch.FloatTensor([0.299, 0.587, 0.114])
_MS_WEIGHTS = torch.FloatTensor([0.096, 0.596, 0.289, 0.019])


def _gmsd(
Expand All @@ -24,6 +31,7 @@ def _gmsd(
kernel: torch.Tensor,
value_range: float = 1.,
c: float = 0.00261, # 170. / (255. ** 2)
alpha: float = 0.,
) -> torch.Tensor:
r"""Returns the GMSD between `x` and `y`,
without downsampling and color space conversion.
Expand All @@ -36,7 +44,7 @@ def _gmsd(
kernel: A 2D gradient kernel, (2, 1, K, K).
value_range: The value range of the inputs (usually 1. or 255).
For the remaining arguments, refer to [1].
For the remaining arguments, refer to [1] and [2].
Example:
>>> x = torch.rand(5, 1, 256, 256)
Expand All @@ -55,10 +63,14 @@ def _gmsd(
gm_x = tensor_norm(filter2d(x, kernel, padding=pad), dim=1)
gm_y = tensor_norm(filter2d(y, kernel, padding=pad), dim=1)

gm_xy = gm_x * gm_y

# Gradient magnitude similarity
gms = (2. * gm_x * gm_y + c) / (gm_x ** 2 + gm_y ** 2 + c)
gms_num = (2. - alpha) * gm_xy + c
gms_den = gm_x ** 2 + gm_y ** 2 - alpha * gm_xy + c
gms = gms_num / gms_den

# Gradient magnitude similarity diviation
# Gradient magnitude similarity deviation
gmsd = (gms - gms.mean((-1, -2), keepdim=True)) ** 2
gmsd = torch.sqrt(gmsd.mean((-1, -2)))

Expand Down Expand Up @@ -108,6 +120,98 @@ def gmsd(
return _gmsd(x, y, kernel, **kwargs)


def _msgmsd(
x: torch.Tensor,
y: torch.Tensor,
kernel: torch.Tensor,
weights: torch.Tensor,
alpha: float = 0.5,
**kwargs,
) -> torch.Tensor:
r"""Returns the MS-GMSD between `x` and `y`,
without color space conversion.
`_msgmsd` is an auxiliary function for `msgmsd` and `MSGMSD`.
Args:
x: An input tensor, (N, 1, H, W).
y: A target tensor, (N, 1, H, W).
kernel: A 2D gradient kernel, (2, 1, K, K).
weights: The weights of the scales, (M,).
`alpha` and `**kwargs` are transmitted to `_gmsd`.
Example:
>>> x = torch.rand(5, 1, 256, 256)
>>> y = torch.rand(5, 1, 256, 256)
>>> kernel = torch.rand(2, 1, 3, 3)
>>> weights = torch.rand(4)
>>> l = _msgmsd(x, y, kernel, weights)
>>> l.size()
torch.Size([5])
"""

gmsds = []

for i in range(weights.numel()):
if i > 0:
x = F.avg_pool2d(x, kernel_size=2, ceil_mode=True)
y = F.avg_pool2d(y, kernel_size=2, ceil_mode=True)

gmsds.append(_gmsd(x, y, kernel, alpha=alpha, **kwargs))

msgmsd = torch.stack(gmsds, dim=-1) ** 2
msgmsd = torch.sqrt((msgmsd * weights).sum(dim=-1))

return msgmsd


def msgmsd(
x: torch.Tensor,
y: torch.Tensor,
kernel: torch.Tensor = None,
weights: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
r"""Returns the MS-GMSD between `x` and `y`.
Args:
x: An input tensor, (N, 3, H, W).
y: A target tensor, (N, 3, H, W).
kernel: A 2D gradient kernel, (2, 1, K, K).
If `None`, use the Prewitt kernel instead.
weights: The weights of the scales, (M,).
If `None`, use the official weights instead.
`**kwargs` are transmitted to `_msgmsd`.
Example:
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> l = msgmsd(x, y)
>>> l.size()
torch.Size([5])
"""

# RGB to luminance
l_weights = _L_WEIGHTS.to(x.device).view(1, 3, 1, 1)

x = F.conv2d(x, l_weights)
y = F.conv2d(y, l_weights)

# Kernel
if kernel is None:
kernel = prewitt_kernel()
kernel = torch.stack([kernel, kernel.t()]).unsqueeze(1)
kernel = kernel.to(x.device)

# Weights
if weights is None:
weights = _MS_WEIGHTS.to(x.device)

return _msgmsd(x, y, kernel, weights, **kwargs)


class GMSD(nn.Module):
r"""Creates a criterion that measures the GMSD
between an input and a target.
Expand Down Expand Up @@ -173,3 +277,73 @@ def forward(
l = _gmsd(input, target, self.kernel, **self.kwargs)

return self.reduce(l)


class MSGMSD(nn.Module):
r"""Creates a criterion that measures the MSGMSD
between an input and a target.
Args:
kernel: A 2D gradient kernel, (2, 1, K, K).
If `None`, use the Prewitt kernel instead.
weights: The weights of the scales, (M,).
If `None`, use the official weights instead.
reduction: Specifies the reduction to apply to the output:
`'none'` | `'mean'` | `'sum'`.
`**kwargs` are transmitted to `_msgmsd`.
Shape:
* Input: (N, 3, H, W)
* Target: (N, 3, H, W)
* Output: (N,) or (1,) depending on `reduction`
Example:
>>> criterion = MSGMSD().cuda()
>>> x = torch.rand(5, 3, 256, 256).cuda()
>>> y = torch.rand(5, 3, 256, 256).cuda()
>>> l = criterion(x, y)
>>> l.size()
torch.Size([])
"""

def __init__(
self,
kernel: torch.Tensor = None,
weights: torch.Tensor = None,
reduction: str = 'mean',
**kwargs,
):
r""""""
super().__init__()

if kernel is None:
kernel = prewitt_kernel()
kernel = torch.stack([kernel, kernel.t()]).unsqueeze(1)

if weights is None:
weights = _MS_WEIGHTS

self.register_buffer('kernel', kernel)
self.register_buffer('weights', weights)
self.register_buffer('l_weights', _L_WEIGHTS.view(1, 3, 1, 1))

self.reduce = build_reduce(reduction)
self.kwargs = kwargs

def forward(
self,
input: torch.Tensor,
target: torch.Tensor,
) -> torch.Tensor:
r"""Defines the computation performed at every call.
"""

# RGB to luminance
input = F.conv2d(input, self.l_weights)
target = F.conv2d(target, self.l_weights)

# MSGMSD
l = _msgmsd(input, target, self.kernel, self.weights, **self.kwargs)

return self.reduce(l)
6 changes: 3 additions & 3 deletions piqa/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from typing import Tuple

_WEIGHTS = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
_MS_WEIGHTS = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333])


def create_window(window_size: int, n_channels: int) -> torch.Tensor:
Expand Down Expand Up @@ -209,7 +209,7 @@ def msssim(
window = create_window(window_size, n_channels).to(x.device)

if weights is None:
weights = _WEIGHTS.to(x.device)
weights = _MS_WEIGHTS.to(x.device)

return msssim_per_channel(x, y, window, weights, **kwargs).mean(-1)

Expand Down Expand Up @@ -313,7 +313,7 @@ def __init__(
super().__init__()

if weights is None:
weights = _WEIGHTS
weights = _MS_WEIGHTS

self.register_buffer('window', create_window(window_size, n_channels))
self.register_buffer('weights', weights)
Expand Down

0 comments on commit 2364fbc

Please sign in to comment.