From 0125e904ff74c1339798a8b46b003d27184ff199 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Rozet?= Date: Wed, 9 Jun 2021 22:23:51 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Visual=20Saliency-based=20Index=20(?= =?UTF-8?q?VSI)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ✨ Feature Similarity (FSIM) ⬆️ Upgrade to PyTorch 1.8.0 (for FFT) ♻️ Refactor color space conversion tools 📝 Add credits to official implementations --- README.md | 4 +- piqa/__init__.py | 4 +- piqa/fsim.py | 304 +++++++++++++++++++++++++++++++++++++++ piqa/gmsd.py | 9 +- piqa/haarpsi.py | 4 +- piqa/mdsi.py | 7 +- piqa/ssim.py | 4 +- piqa/utils/color.py | 149 ++++++++++++++----- piqa/utils/complex.py | 58 ++++++++ piqa/utils/functional.py | 69 +++++++++ piqa/vsi.py | 259 +++++++++++++++++++++++++++++++++ requirements.txt | 4 +- tests/benchmark.py | 25 ++-- tests/doctests.py | 4 + 14 files changed, 845 insertions(+), 59 deletions(-) create mode 100644 piqa/fsim.py create mode 100644 piqa/vsi.py diff --git a/README.md b/README.md index d947259..a7c5c94 100644 --- a/README.md +++ b/README.md @@ -87,12 +87,14 @@ l = ssim(x, y, kernel=kernel, channel_avg=False) | TV | `TV` | `[0, ∞]` | / | 1937 | [Total Variation](https://en.wikipedia.org/wiki/Total_variation) | | PSNR | `PSNR` | `[0, ∞]` | max | / | [Peak Signal-to-Noise Ratio](https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio) | | SSIM | `SSIM` | `[0, 1]` | max | 2004 | [Structural Similarity](https://en.wikipedia.org/wiki/Structural_similarity) | -| MS-SSIM | `MS_SSIM` | `[0, 1]` | max | 2004 | [Multi-Scale Structural Similarity](https://ieeexplore.ieee.org/abstract/document/1292216/) | +| MS-SSIM | `MS_SSIM` | `[0, 1]` | max | 2004 | [Multi-Scale Structural Similarity](https://ieeexplore.ieee.org/document/1292216/) | | LPIPS | `LPIPS` | `[0, ∞]` | min | 2018 | [Learned Perceptual Image Patch Similarity](https://arxiv.org/abs/1801.03924) | | GMSD | `GMSD` | `[0, ∞]` | min | 2013 | [Gradient Magnitude Similarity Deviation](https://arxiv.org/abs/1308.3052) | | MS-GMSD | `MS_GMSD` | `[0, ∞]` | min | 2017 | [Multi-Scale Gradient Magnitude Similarity Deviation](https://ieeexplore.ieee.org/document/7952357) | | MDSI | `MDSI` | `[0, ∞]` | min | 2016 | [Mean Deviation Similarity Index](https://arxiv.org/abs/1608.07433) | | HaarPSI | `HaarPSI` | `[0, 1]` | max | 2018 | [Haar Perceptual Similarity Index](https://arxiv.org/abs/1607.06140) | +| VSI | `VSI` | `[0, 1]` | max | 2014 | [Visual Saliency-based Index](https://ieeexplore.ieee.org/document/6873260) | +| FSIM | `FSIM` | `[0, 1]` | max | 2011 | [Feature Similarity](https://ieeexplore.ieee.org/document/5705575) | ### JIT diff --git a/piqa/__init__.py b/piqa/__init__.py index 71017ad..0ed5c24 100644 --- a/piqa/__init__.py +++ b/piqa/__init__.py @@ -5,7 +5,7 @@ specific image quality assessement metric. """ -__version__ = '1.1.4' +__version__ = '1.1.5' from .tv import TV from .psnr import PSNR @@ -14,3 +14,5 @@ from .gmsd import GMSD, MS_GMSD from .mdsi import MDSI from .haarpsi import HaarPSI +from .vsi import VSI +from .fsim import FSIM diff --git a/piqa/fsim.py b/piqa/fsim.py new file mode 100644 index 0000000..e0835cd --- /dev/null +++ b/piqa/fsim.py @@ -0,0 +1,304 @@ +"""Feature Similarity (FSIM) + +This module implements the FSIM in PyTorch. + +Credits: + Inspired by the [official implementation](https://www4.comp.polyu.edu.hk/~cslzhang/IQA/FSIM/FSIM.htm) + +References: + [1] FSIM: A Feature Similarity Index for Image Quality Assessment + (Zhang et al., 2011) + https://ieeexplore.ieee.org/document/5705575 + + [2] Image Features From Phase Congruency + (Kovesi, 1999) + https://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.4.1641 +""" + +import math +import torch +import torch.fft as fft +import torch.nn as nn +import torch.nn.functional as F + +from piqa.utils import _jit, _assert_type, _reduce +from piqa.utils.color import ColorConv +from piqa.utils.functional import ( + scharr_kernel, + gradient_kernel, + filter_grid, + log_gabor, + channel_conv, +) + +import piqa.utils.complex as cx + + +@_jit +def phase_congruency( + x: torch.Tensor, + value_range: float = 1., + scales: int = 4, + orientations: int = 4, + wavelength: float = 6., + factor: float = 2., + sigma_f: float = 0.5978, # -log(0.55) + sigma_theta: float = 0.6545, # pi / (4 * 1.2) + k: float = 2., + rescale: float = 1.7, + eps: float = 1e-8, +) -> torch.Tensor: + r"""Returns the phase congruency of \(x\). + + Args: + x: An input tensor, \((N, 1, H, W)\). + + For the remaining arguments, refer to [2]. + + Returns: + The PC tensor, \((N, H, W)\). + + Example: + >>> x = torch.rand(5, 1, 256, 256) + >>> l = phase_congruency(x) + >>> l.size() + torch.Size([5, 256, 256]) + """ + + x = x * (255. / value_range) + + # log-Gabor filters + r, theta = filter_grid(x) # (H, W) + + ## Radial + lowpass = 1 / (1 + (r / 0.45) ** (2 * 15)) + + a = torch.stack([ + log_gabor(r, 1 / (wavelength * factor ** i), sigma_f) * lowpass + for i in range(scales) + ]) + + ## Angular + cos_theta = torch.cos(theta) + sin_theta = torch.sin(theta) + + theta_j = math.pi * torch.arange(orientations).to(x) / orientations + theta_j = theta_j.view(orientations, 1, 1) + + # Measure (theta - theta_j) in the sine/cosine domains + # to prevent wrap-around errors + delta_sin = sin_theta * theta_j.cos() - cos_theta * theta_j.sin() + delta_cos = cos_theta * theta_j.cos() + sin_theta * theta_j.sin() + delta_theta = torch.atan2(delta_sin, delta_cos) + + b = torch.exp(-delta_theta ** 2 / (2 * sigma_theta ** 2)) + + ## Combine + filters = a[:, None] * b[None, :] + + # Even & odd (real and imaginary) filter responses + eo = fft.ifft2(fft.fft2(x[:, None]) * filters) + eo = torch.view_as_real(eo) # (N, scales, orientations, H, W, 2) + + ## Amplitude + a = cx.mod(eo) + + ## Energy + sum_eo = eo.sum(dim=1, keepdim=True) + mean_eo = sum_eo / (cx.mod(sum_eo)[..., None] + eps) + + rot90_eo = cx.complex(-cx.imag(eo), cx.real(eo)) + + energy = cx.dot(eo, mean_eo) - cx.dot(rot90_eo, mean_eo).abs() + energy = energy.sum(dim=1, keepdim=True) + energy = energy.squeeze(1) # (N, orientations, H, W) + + # Noise + e2 = a[:, 0] ** 2 + median_e2, _ = torch.median(e2.flatten(-2), dim=-1) + mean_e2 = -median_e2 / math.log(0.5) + + em = (filters[0] ** 2).sum(dim=(-1, -2)) + noise_power = mean_e2 / em + + ## Total energy^2 due to noise + ifft_filters = fft.ifft2(filters) + ifft_filters = cx.real(torch.view_as_real(ifft_filters)) + + sum_aiaj = (ifft_filters[None, :] * ifft_filters[:, None]).sum(dim=(0, 1, 3, 4)) + sum_aiaj = sum_aiaj * r.numel() + + noise_energy2 = noise_power * sum_aiaj # (N, orientations) + noise_energy2 = noise_energy2[..., None, None] + + ## Noise threshold + tau = noise_energy2.sqrt() # Rayleigh parameter + + c, d = (math.pi / 2) ** 0.5, (2 - math.pi / 2) ** 0.5 + noise_threshold = tau * (c + k * d) + noise_threshold = noise_threshold / rescale # emprirical rescaling + + energy = (energy - noise_threshold).relu() + + # Phase congruency + pc = energy.sum(dim=1) / (a.sum(dim=(1, 2)) + eps) # (N, H, W) + + return pc + + +@_jit +def fsim( + x: torch.Tensor, + y: torch.Tensor, + kernel: torch.Tensor, + value_range: float = 1., + t1: float = 0.85, + t2: float = 160. / (255. ** 2), + t3: float = 200. / (255. ** 2), + t4: float = 200. / (255. ** 2), + lmbda: float = 0.03, +) -> torch.Tensor: + r"""Returns the FSIM between \(x\) and \(y\), + without color space conversion and downsampling. + + Args: + x: An input tensor, \((N, 3 \text{ or } 1, H, W)\). + y: A target tensor, \((N, 3 \text{ or } 1, H, W)\). + kernel: A gradient kernel, \((2, 1, K, K)\). + value_range: The value range \(L\) of the inputs (usually 1. or 255). + + For the remaining arguments, refer to [1]. + + Returns: + The FSIM vector, \((N,)\). + + Example: + >>> x = torch.rand(5, 3, 256, 256) + >>> y = torch.rand(5, 3, 256, 256) + >>> kernel = gradient_kernel(scharr_kernel()) + >>> l = fsim(x, y, kernel) + >>> l.size() + torch.Size([5]) + """ + + t2 *= value_range ** 2 + t3 *= value_range ** 2 + t4 *= value_range ** 2 + + y_x, y_y = x[:, :1], y[:, :1] + + # Phase congruency similarity + pc_x = phase_congruency(y_x, value_range) + pc_y = phase_congruency(y_y, value_range) + pc_m = torch.max(pc_x, pc_y) + + s_pc = (2 * pc_x * pc_y + t1) / (pc_x ** 2 + pc_y ** 2 + t1) + + # Gradient magnitude similarity + pad = kernel.size(-1) // 2 + + g_x = torch.linalg.norm(channel_conv(y_x, kernel, padding=pad), dim=1) + g_y = torch.linalg.norm(channel_conv(y_y, kernel, padding=pad), dim=1) + + s_g = (2 * g_x * g_y + t2) / (g_x ** 2 + g_y ** 2 + t2) + + # Chrominance similarity + s_l = s_pc * s_g + + if x.size(1) == 3: + i_x, i_y = x[:, 1], y[:, 1] + q_x, q_y = x[:, 2], y[:, 2] + + s_i = (2 * i_x * i_y + t3) / (i_x ** 2 + i_y ** 2 + t3) + s_q = (2 * q_x * q_y + t4) / (q_x ** 2 + q_y ** 2 + t4) + + s_iq = s_i * s_q + s_iq = cx.complex(s_iq, torch.zeros_like(s_iq)) + s_iq_lambda = cx.real(cx.pow(s_iq, lmbda)) + + s_l = s_l * s_iq_lambda + + # Feature similarity + fs = (s_l * pc_m).sum(dim=(-1, -2)) / pc_m.sum(dim=(-1, -2)) + + return fs + + +class FSIM(nn.Module): + r"""Creates a criterion that measures the FSIM + between an input and a target. + + Before applying `fsim`, the input and target are converted from + RBG to Y(IQ) and downsampled by a factor \( \frac{\min(H, W)}{256} \). + + Args: + chromatic: Whether to use the chromatic channels (IQ) or not. + kernel: A gradient kernel, \((2, 1, K, K)\). + If `None`, use the Scharr kernel instead. + reduction: Specifies the reduction to apply to the output: + `'none'` | `'mean'` | `'sum'`. + + `**kwargs` are transmitted to `fsim`. + + Example: + >>> criterion = FSIM().cuda() + >>> x = torch.rand(5, 3, 256, 256, requires_grad=True).cuda() + >>> y = torch.rand(5, 3, 256, 256).cuda() + >>> l = 1 - criterion(x, y) + >>> l.size() + torch.Size([]) + >>> l.backward() + """ + + def __init__( + self, + chromatic: bool = True, + kernel: torch.Tensor = None, + reduction: str = 'mean', + **kwargs, + ): + r"""""" + super().__init__() + + if kernel is None: + kernel = gradient_kernel(scharr_kernel()) + + self.register_buffer('kernel', kernel) + + self.convert = ColorConv('RGB', 'YIQ' if chromatic else 'Y') + self.reduction = reduction + self.value_range = kwargs.get('value_range', 1.) + self.kwargs = kwargs + + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + ) -> torch.Tensor: + r"""Defines the computation performed at every call. + """ + + _assert_type( + [input, target], + device=self.kernel.device, + dim_range=(4, 4), + n_channels=3, + value_range=(0., self.value_range), + ) + + # Downsample + _, _, h, w = input.size() + M = round(min(h, w) / 256) + + if M > 1: + input = F.avg_pool2d(input, kernel_size=M, ceil_mode=True) + target = F.avg_pool2d(target, kernel_size=M, ceil_mode=True) + + # RGB to Y(IQ) + input = self.convert(input) + target = self.convert(target) + + # FSIM + l = fsim(input, target, kernel=self.kernel, **self.kwargs) + + return _reduce(l, self.reduction) diff --git a/piqa/gmsd.py b/piqa/gmsd.py index 77f0202..30c75e9 100644 --- a/piqa/gmsd.py +++ b/piqa/gmsd.py @@ -3,6 +3,9 @@ This module implements the GMSD and MS-GMSD in PyTorch. +Credits: + Inspired by the [official implementation](https://www4.comp.polyu.edu.hk/~cslzhang/IQA/GMSD/GMSD.htm) + References: [1] Gradient Magnitude Similarity Deviation: An Highly Efficient Perceptual Image Quality Index @@ -20,7 +23,7 @@ import torch.nn.functional as F from piqa.utils import _jit, _assert_type, _reduce -from piqa.utils.color import get_conv +from piqa.utils.color import ColorConv from piqa.utils.functional import ( prewitt_kernel, gradient_kernel, @@ -200,7 +203,7 @@ def __init__( self.register_buffer('kernel', kernel) - self.convert = get_conv('RGB', 'Y') + self.convert = ColorConv('RGB', 'Y') self.reduction = reduction self.value_range = kwargs.get('value_range', 1.) self.kwargs = kwargs @@ -290,7 +293,7 @@ def __init__( self.register_buffer('kernel', kernel) self.register_buffer('weights', weights) - self.convert = get_conv('RGB', 'Y') + self.convert = ColorConv('RGB', 'Y') self.reduction = reduction self.value_range = kwargs.get('value_range', 1.) self.kwargs = kwargs diff --git a/piqa/haarpsi.py b/piqa/haarpsi.py index 2a78789..90ebcc6 100644 --- a/piqa/haarpsi.py +++ b/piqa/haarpsi.py @@ -20,7 +20,7 @@ import torch.nn.functional as F from piqa.utils import _jit, _assert_type, _reduce -from piqa.utils.color import get_conv +from piqa.utils.color import ColorConv from piqa.utils.functional import ( haar_kernel, gradient_kernel, @@ -155,7 +155,7 @@ def __init__( r"""""" super().__init__() - self.convert = get_conv('RGB', 'YIQ' if chromatic else 'Y') + self.convert = ColorConv('RGB', 'YIQ' if chromatic else 'Y') self.reduction = reduction self.value_range = kwargs.get('value_range', 1.) self.kwargs = kwargs diff --git a/piqa/mdsi.py b/piqa/mdsi.py index 977f971..02235ae 100644 --- a/piqa/mdsi.py +++ b/piqa/mdsi.py @@ -2,6 +2,9 @@ This module implements the MDSI in PyTorch. +Credits: + Inspired by the [official implementation](https://www.mathworks.com/matlabcentral/fileexchange/59809-mdsi-ref-dist-combmethod) + References: [1] Mean Deviation Similarity Index: Efficient and Reliable Full-Reference Image Quality Evaluator @@ -14,7 +17,7 @@ import torch.nn.functional as F from piqa.utils import _jit, _assert_type, _reduce -from piqa.utils.color import get_conv +from piqa.utils.color import ColorConv from piqa.utils.functional import ( prewitt_kernel, gradient_kernel, @@ -159,7 +162,7 @@ def __init__( self.register_buffer('kernel', kernel) - self.convert = get_conv('RGB', 'LHM') + self.convert = ColorConv('RGB', 'LHM') self.reduction = reduction self.value_range = kwargs.get('value_range', 1.) self.kwargs = kwargs diff --git a/piqa/ssim.py b/piqa/ssim.py index 5634ad1..2278d42 100644 --- a/piqa/ssim.py +++ b/piqa/ssim.py @@ -11,11 +11,11 @@ References: [1] Image quality assessment: From error visibility to structural similarity (Wang et al., 2004) - https://ieeexplore.ieee.org/abstract/document/1284395/ + https://ieeexplore.ieee.org/document/1284395/ [2] Multiscale structural similarity for image quality assessment (Wang et al., 2004) - https://ieeexplore.ieee.org/abstract/document/1292216/ + https://ieeexplore.ieee.org/document/1292216/ """ import torch diff --git a/piqa/utils/color.py b/piqa/utils/color.py index 070b0e8..c2a949e 100644 --- a/piqa/utils/color.py +++ b/piqa/utils/color.py @@ -5,62 +5,143 @@ import torch.nn as nn import torch.nn.functional as F +from typing import Tuple + + +def spatial(x: torch.Tensor) -> int: + r"""Returns the number of spatial dimensions of \(x\).""" + + return len(x.shape) - 2 + + +def color_conv( + x: torch.Tensor, + weight: torch.Tensor, +) -> torch.Tensor: + r"""Returns the color convolution of \(x\) with the kernel `weight`. + + Args: + x: A tensor, \((N, C, *)\). + weight: A weight kernel, \((C', C)\). + """ + + return F.conv1d(x, weight.view(weight.shape + (1,) * spatial(x))) + + +RGB_TO_YIQ = torch.tensor([ + [0.299, 0.587, 0.114], + [0.5969, -0.2746, -0.3213], + [0.2115, -0.5227, 0.3112], +]) + +RGB_TO_LHM = torch.tensor([ + [0.2989, 0.5870, 0.1140], + [0.3, 0.04, -0.35], + [0.34, -0.6, 0.17], +]) + +RGB_TO_LMN = torch.tensor([ + [0.06, 0.63, 0.27], + [0.30, 0.04, -0.35], + [0.34, -0.6, 0.17], +]) + +RGB_TO_LMN = torch.tensor([ + [0.06, 0.63, 0.27], + [0.30, 0.04, -0.35], + [0.34, -0.6, 0.17], +]) + _WEIGHTS = { - ('RGB', 'YIQ'): torch.tensor([ # HaarPSI - [0.299, 0.587, 0.114], - [0.5969, -0.2746, -0.3213], - [0.2115, -0.5227, 0.3112], - ]), - ('RGB', 'Y'): torch.tensor([ # GMSD - [0.299, 0.587, 0.114], - ]), - ('RGB', 'LHM'): torch.tensor([ # MDSI - [0.2989, 0.5870, 0.1140], - [0.3, 0.04, -0.35], - [0.34, -0.6, 0.17], - ]), + ('RGB', 'YIQ'): RGB_TO_YIQ, # HaarPSI + ('RGB', 'Y'): RGB_TO_YIQ[0:1], # GMSD + ('RGB', 'LHM'): RGB_TO_LHM, # MDSI + ('RGB', 'LMN'): RGB_TO_LMN, # VSI } -def get_conv(src: str, dst: str, dim: int = 2) -> nn.Module: - r"""Returns a color conversion module. +class ColorConv(nn.Module): + r"""Color convolution module. Args: src: The source color space. E.g. `RGB`. dst: The destination color space. E.g. `YIQ`. - dim: The number of space dimensions. E.g. 2 for 2D images. - - Returns: - The color conversion module. Example: >>> x = torch.rand(5, 3, 256, 256) - >>> conv = get_conv('RGB', 'YIQ', dim=2) + >>> conv = ColorConv('RGB', 'YIQ') >>> y = conv(x) >>> y.size() torch.Size([5, 3, 256, 256]) """ - assert (src, dst) in _WEIGHTS, f'Unknown {src} to {dst} conversion' + def __init__(self, src: str, dst: str): + super().__init__() + + assert (src, dst) in _WEIGHTS, f'Unknown {src} to {dst} conversion' + + self.register_buffer('weight', _WEIGHTS[(src, dst)]) + + @property + def device(self) -> torch.device: + return self.weight.device + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return color_conv(x, self.weight) + + +def rgb_to_xyz(x: torch.Tensor, value_range: float = 1.) -> torch.Tensor: + r"""Converts from sRGB to (CIE) XYZ. + + References: + https://en.wikipedia.org/wiki/SRGB + """ + + x = x / value_range - weight = _WEIGHTS[(src, dst)] - weight = weight.view(weight.size() + (1,) * dim) + mask = x <= 0.04045 + left = x / 12.92 + right = ((x + 0.055) / 1.055) ** 2.4 - return _ColorConv(weight) + x = torch.where(mask, left, right) + weight = torch.tensor([ + [0.4124564, 0.3575761, 0.1804375], + [0.2126729, 0.7151522, 0.0721750], + [0.0193339, 0.1191920, 0.9503041], + ]) -class _ColorConv(nn.Module): - r"""Color Conversion/Convolution module""" + return color_conv(x, weight.to(x)) - def __init__(self, weight: torch.Tensor): - super().__init__() - self.register_buffer('weight', weight) +def xyz_to_lab( + x: torch.Tensor, + illuminants: Tuple[float, float, float] = (0.956797052643698, 1., 0.9214805860173273), +) -> torch.Tensor: + r"""Converts from (CIE) XYZ to (CIE) LAB. - @property - def device(self): - return self.weight.device + References: + https://en.wikipedia.org/wiki/CIELAB_color_space + """ - def forward(self, x: torch.Tensor) -> torch.Tensor: - return F.conv1d(x, self.weight) + scale = torch.tensor(illuminants).view((3,) + (1,) * spatial(x)) + x = x / scale.to(x) + + delta = 6 / 29 + + mask = x > delta ** 3 + left = x ** (1 / 3) + right = x / (3 * delta ** 2) + 4 / 29 + + x = torch.where(mask, left, right) + + weight = torch.tensor([ + [0., 116., 0.], + [500., -500., 0.], + [0., 200., -200.], + ]) + + bias = torch.tensor([-16., 0., 0.]).view(scale.shape) + + return color_conv(x, weight.to(x)) + bias.to(x) diff --git a/piqa/utils/complex.py b/piqa/utils/complex.py index b90e847..325b53f 100644 --- a/piqa/utils/complex.py +++ b/piqa/utils/complex.py @@ -28,6 +28,42 @@ def complex(real: torch.Tensor, imag: torch.Tensor) -> torch.Tensor: return torch.stack([real, imag], dim=-1) +def real(x: torch.Tensor) -> torch.Tensor: + r"""Returns the real part of \(x\). + + Args: + x: A complex tensor, \((*, 2)\). + + Returns: + The real tensor, \((*,)\). + + Example: + >>> x = torch.tensor([[2., 0.], [0.7071, 0.7071]]) + >>> real(x) + tensor([2.0000, 0.7071]) + """ + + return x[..., 0] + + +def imag(x: torch.Tensor) -> torch.Tensor: + r"""Returns the imaginary part of \(x\). + + Args: + x: A complex tensor, \((*, 2)\). + + Returns: + The imaginary tensor, \((*,)\). + + Example: + >>> x = torch.tensor([[2., 0.], [0.7071, 0.7071]]) + >>> imag(x) + tensor([0.0000, 0.7071]) + """ + + return x[..., 1] + + def polar(r: torch.Tensor, phi: torch.Tensor) -> torch.Tensor: r"""Returns a complex tensor with its modulus equal to \(r\) and its phase equal to \(\phi\). @@ -125,6 +161,28 @@ def prod(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return complex(x_r * y_r - x_i * y_i, x_i * y_r + x_r * y_i) +def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + r"""Returns the element-wise dot-product of \(x\) and \(y\). + + $$ x \odot y = \Re(x) \Re(y) + \Im(x) \Im(y) $$ + + Args: + x: A complex tensor, \((*, 2)\). + y: A complex tensor, \((*, 2)\). + + Returns: + The dot-product tensor, \((*,)\). + + Example: + >>> x = torch.tensor([[2., 0.], [0.7071, 0.7071]]) + >>> y = torch.tensor([[2., -0.], [0.7071, -0.7071]]) + >>> dot(x, y) + tensor([4., 0.]) + """ + + return (x * y).sum(dim=-1) + + def pow(x: torch.Tensor, exponent: float) -> torch.Tensor: r"""Returns the power of \(x\) with `exponent`. diff --git a/piqa/utils/functional.py b/piqa/utils/functional.py index 8ab941c..de99119 100644 --- a/piqa/utils/functional.py +++ b/piqa/utils/functional.py @@ -2,6 +2,7 @@ """ import torch +import torch.fft as fft import torch.nn as nn import torch.nn.functional as F @@ -266,3 +267,71 @@ def gradient_kernel(kernel: torch.Tensor) -> torch.Tensor: """ return torch.stack([kernel, kernel.t()]).unsqueeze(1) + + +def filter_grid(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Returns the (quadrant-shifted) frequency grid for \(x\). + + Args: + x: An input tensor, \((*, H, W)\). + + Returns: + The radius and phase tensors, both \((H, W)\). + + Example: + >>> x = torch.rand(5, 5) + >>> r, phi = filter_grid(x) + >>> r + tensor([[0.0000, 0.2500, 0.5000, 0.5000, 0.2500], + [0.2500, 0.3536, 0.5590, 0.5590, 0.3536], + [0.5000, 0.5590, 0.7071, 0.7071, 0.5590], + [0.5000, 0.5590, 0.7071, 0.7071, 0.5590], + [0.2500, 0.3536, 0.5590, 0.5590, 0.3536]]) + >>> phi + tensor([[-0.0000, -1.5708, -1.5708, 1.5708, 1.5708], + [-0.0000, -0.7854, -1.1071, 1.1071, 0.7854], + [-0.0000, -0.4636, -0.7854, 0.7854, 0.4636], + [-3.1416, -2.6779, -2.3562, 2.3562, 2.6779], + [-3.1416, -2.3562, -2.0344, 2.0344, 2.3562]]) + """ + + u, v = [ + (torch.arange(n).to(x) - n // 2) / (n - n % 2) + for n in x.shape[-2:] + ] + u, v = fft.ifftshift(u[:, None]), fft.ifftshift(v[None, :]) + + r = (u ** 2 + v ** 2).sqrt() + phi = torch.atan2(-v, u) + + return r, phi + + +def log_gabor(f: torch.Tensor, f_0: float, sigma_f: float) -> torch.Tensor: + r"""Returns the log-Gabor filter of \(f\). + + $$ G(f) = \exp \left( - \frac{\log(f / f_0)^2}{2 \sigma_f^2} \right) $$ + + Args: + f: A frequency tensor, \((*,)\). + f_0: The center frequency \(f_0\). + sigma_f: The bandwidth (log-)deviation \(\sigma_f\). + + Returns: + The filter tensor, \((*,)\). + + Wikipedia: + https://en.wikipedia.org/wiki/Log_Gabor_filter + + Example: + >>> x = torch.rand(5, 5) + >>> r, phi = filter_grid(x) + >>> log_gabor(r, 1., 1.) + tensor([[0.0000, 0.3825, 0.7864, 0.7864, 0.3825], + [0.3825, 0.5825, 0.8444, 0.8444, 0.5825], + [0.7864, 0.8444, 0.9417, 0.9417, 0.8444], + [0.7864, 0.8444, 0.9417, 0.9417, 0.8444], + [0.3825, 0.5825, 0.8444, 0.8444, 0.5825]]) + """ + + return torch.exp(- (f / f_0).log() ** 2 / (2 * sigma_f ** 2)) diff --git a/piqa/vsi.py b/piqa/vsi.py new file mode 100644 index 0000000..008936c --- /dev/null +++ b/piqa/vsi.py @@ -0,0 +1,259 @@ +r"""Visual Saliency-based Index (VSI) + +This module implements the VSI in PyTorch. + +Wikipedia: + https://en.wikipedia.org/wiki/Salience_(neuroscience)#Visual_saliency_modeling + +Credits: + Inspired by the [official implementation](https://sse.tongji.edu.cn/linzhang/IQA/VSI/VSI.htm) + +References: + [1] VSI: A Visual Saliency-Induced Index for Perceptual Image Quality Assessment + (Zhang et al., 2014) + https://ieeexplore.ieee.org/document/6873260 + + [2] SDSP: A novel saliency detection method by combining simple priors + (Zhang et al., 2013) + https://ieeexplore.ieee.org/document/6738036 +""" + +import torch +import torch.fft as fft +import torch.nn as nn +import torch.nn.functional as F + +from piqa.utils import _jit, _assert_type, _reduce +from piqa.utils.color import ColorConv, rgb_to_xyz, xyz_to_lab +from piqa.utils.functional import ( + scharr_kernel, + gradient_kernel, + filter_grid, + log_gabor, + channel_conv, +) + +import piqa.utils.complex as cx + + +@_jit +def vsi( + x: torch.Tensor, + y: torch.Tensor, + vs_x: torch.Tensor, + vs_y: torch.Tensor, + kernel: torch.Tensor, + value_range: float = 1., + c1: float = 1.27, + c2: float = 386. / (255. ** 2), + c3: float = 130. / (255. ** 2), + alpha: float = 0.4, + beta: float = 0.02, +) -> torch.Tensor: + r"""Returns the VSI between \(x\) and \(y\), + without downsampling and color space conversion. + + Args: + x: An input tensor, \((N, 3, H, W)\). + y: A target tensor, \((N, 3, H, W)\). + vs_x: The input visual salience, \((N, H, W)\). + vs_y: The target visual salience, \((N, H, W)\). + kernel: A gradient kernel, \((2, 1, K, K)\). + value_range: The value range \(L\) of the inputs (usually 1. or 255). + + For the remaining arguments, refer to [1]. + + Returns: + The VSI vector, \((N,)\). + + Example: + >>> x = torch.rand(5, 3, 256, 256) + >>> y = torch.rand(5, 3, 256, 256) + >>> vs_x, vs_y = sdsp(x), sdsp(y) + >>> kernel = gradient_kernel(scharr_kernel()) + >>> l = vsi(x, y, vs_x, vs_y, kernel) + >>> l.size() + torch.Size([5]) + """ + + c2 *= value_range ** 2 + c3 *= value_range ** 2 + + l_x, mn_x = x[:, :1], x[:, 1:] + l_y, mn_y = y[:, :1], y[:, 1:] + + # Visual saliency similarity + vs_m = torch.max(vs_x, vs_y) + s_vs = (2 * vs_x * vs_y + c1) / (vs_x ** 2 + vs_y ** 2 + c1) + + # Gradient magnitude similarity + pad = kernel.size(-1) // 2 + + g_x = torch.linalg.norm(channel_conv(l_x, kernel, padding=pad), dim=1) + g_y = torch.linalg.norm(channel_conv(l_y, kernel, padding=pad), dim=1) + + s_g = (2 * g_x * g_y + c2) / (g_x ** 2 + g_y ** 2 + c2) + + # Chorminance similarity + s_c = (2 * mn_x * mn_y + c3) / (mn_x ** 2 + mn_y ** 2 + c3) + s_c = s_c.prod(dim=1) + + s_c = cx.complex(s_c, torch.zeros_like(s_c)) + s_c_beta = cx.real(cx.pow(s_c, beta)) + + # Visual Saliency-based Index + s = s_vs * s_g ** alpha * s_c_beta + vsi = (s * vs_m).sum(dim=(-1, -2)) / vs_m.sum(dim=(-1, -2)) + + return vsi + + +@_jit +def sdsp( + x: torch.Tensor, + value_range: float = 1., + omega_0: float = 0.021, + sigma_f: float = 1.34, + sigma_c: float = 0.001, + sigma_d: float = 145., +) -> torch.Tensor: + r"""Detects salient regions from \(x\). + + Args: + x: An input tensor, \((N, 3, H, W)\). + value_range: The value range \(L\) of the input (usually 1. or 255). + + For the remaining arguments, refer to [2]. + + Returns: + The visual saliency tensor, \((N, H, W)\). + + Example: + >>> x = torch.rand(5, 3, 256, 256) + >>> l = sdsp(x) + >>> l.size() + torch.Size([5, 256, 256]) + """ + + x_lab = xyz_to_lab(rgb_to_xyz(x, value_range)) + + # Frequency prior + w, _ = filter_grid(x_lab) + lg = log_gabor(w, omega_0, sigma_f) + lg = lg * (w <= 0.5) # low-pass + + x_f = fft.ifft2(fft.fft2(x_lab) * lg) + x_f = cx.real(torch.view_as_real(x_f)) + + s_f = torch.linalg.norm(x_f, dim=1) + + # Color prior + x_ab = x_lab[:, 1:] + + lo, _ = x_ab.view(x_ab.shape[:2] + (-1,)).min(dim=-1) + up, _ = x_ab.view(x_ab.shape[:2] + (-1,)).max(dim=-1) + + lo = lo.view(x_ab.shape[:2] + (1, 1)) + up = up.view(lo.shape) + span = torch.where(up > lo, up - lo, torch.tensor(1.).to(lo)) + + x_ab = (x_ab - lo) / span + + s_c = 1. - torch.exp(-torch.sum(x_ab ** 2, dim=1) / sigma_c ** 2) + + # Location prior + a, b = [ + torch.arange(n).to(x) - (n - 1) / 2 + for n in x.shape[2:] + ] + + s_d = torch.exp(-(a[None, :] ** 2 + b[:, None] ** 2) / sigma_d ** 2) + + # Visual saliency + vs = s_f * s_c * s_d + + return vs + + +class VSI(nn.Module): + r"""Creates a criterion that measures the VSI + between an input and a target. + + Before applying `vsi`, the input and target are converted from + RBG to LMN and downsampled by a factor \( \frac{\min(H, W)}{256} \). + + The visual saliency maps of the input and target are determined by `sdsp`. + + Args: + kernel: A gradient kernel, \((2, 1, K, K)\). + If `None`, use the Scharr kernel instead. + reduction: Specifies the reduction to apply to the output: + `'none'` | `'mean'` | `'sum'`. + + `**kwargs` are transmitted to `vsi`. + + Example: + >>> criterion = VSI().cuda() + >>> x = torch.rand(5, 3, 256, 256, requires_grad=True).cuda() + >>> y = torch.rand(5, 3, 256, 256).cuda() + >>> l = 1 - criterion(x, y) + >>> l.size() + torch.Size([]) + >>> l.backward() + """ + + def __init__( + self, + kernel: torch.Tensor = None, + reduction: str = 'mean', + **kwargs, + ): + r"""""" + super().__init__() + + if kernel is None: + kernel = gradient_kernel(scharr_kernel()) + + self.register_buffer('kernel', kernel) + + self.convert = ColorConv('RGB', 'LMN') + self.reduction = reduction + self.value_range = kwargs.get('value_range', 1.) + self.kwargs = kwargs + + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + ) -> torch.Tensor: + r"""Defines the computation performed at every call. + """ + + _assert_type( + [input, target], + device=self.kernel.device, + dim_range=(4, 4), + n_channels=3, + value_range=(0., self.value_range), + ) + + # Downsample + _, _, h, w = input.size() + M = round(min(h, w) / 256) + + if M > 1: + input = F.avg_pool2d(input, kernel_size=M, ceil_mode=True) + target = F.avg_pool2d(target, kernel_size=M, ceil_mode=True) + + # Visual saliancy + vs_input = sdsp(input, self.value_range) + vs_target = sdsp(target, self.value_range) + + # RGB to LMN + input = self.convert(input) + target = self.convert(target) + + # VSI + l = vsi(input, target, vs_input, vs_target, kernel=self.kernel, **self.kwargs) + + return _reduce(l, self.reduction) diff --git a/requirements.txt b/requirements.txt index df940a6..8c13589 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -torch>=1.7.0 -torchvision>=0.8.0 +torch>=1.8.0 +torchvision>=0.9.0 diff --git a/tests/benchmark.py b/tests/benchmark.py index f8f43a1..34470b1 100644 --- a/tests/benchmark.py +++ b/tests/benchmark.py @@ -39,7 +39,6 @@ 'TV': (1, { 'kornia.tv': kornia.total_variation, 'piq.tv': lambda x: piq.total_variation(x, norm_type='l1'), - 'piq.TV': piq.TVLoss(norm_type='l1'), 'piqa.TV': piqa.TV(), }), 'PSNR': (2, { @@ -55,49 +54,50 @@ multichannel=True, gaussian_weights=True, ), - 'piq.ssim': piq.ssim, + 'piq.ssim': lambda x, y: piq.ssim(x, y, downsample=False), 'kornia.SSIM-halfloss': kornia.SSIM( window_size=11, reduction='mean', ), - 'piq.SSIM-loss': piq.SSIMLoss(), 'IQA.SSIM-loss': IQA.SSIM(), 'vainf.SSIM': vainf.SSIM(data_range=1.), 'piqa.SSIM': piqa.SSIM(), }), 'MS-SSIM': (2, { 'piq.ms_ssim': piq.multi_scale_ssim, - 'piq.MS_SSIM-loss': piq.MultiScaleSSIMLoss(), 'IQA.MS_SSIM-loss': IQA.MS_SSIM(), 'vainf.MS_SSIM': vainf.MS_SSIM(data_range=1.), 'piqa.MS_SSIM': piqa.MS_SSIM(), }), 'LPIPS': (2, { - 'piq.LPIPS': piq.LPIPS(), - # 'IQA.LPIPS': IQA.LPIPSvgg(), + # 'piq.LPIPS': piq.LPIPS(), + 'IQA.LPIPS': IQA.LPIPSvgg(), 'piqa.LPIPS': piqa.LPIPS(network='vgg') }), 'GMSD': (2, { 'piq.gmsd': piq.gmsd, - 'piq.GMSD': piq.GMSDLoss(), - # 'IQA.GMSD': IQA.GMSD(), 'piqa.GMSD': piqa.GMSD(), }), 'MS-GMSD': (2, { 'piq.ms_gmsd': piq.multi_scale_gmsd, - 'piq.MS_GMSD': piq.MultiScaleGMSDLoss(), 'piqa.MS_GMSD': piqa.MS_GMSD(), }), 'MDSI': (2, { 'piq.mdsi': piq.mdsi, - 'piq.MDSI-loss': piq.MDSILoss(), 'piqa.MDSI': piqa.MDSI(), }), 'HaarPSI': (2, { 'piq.haarpsi': piq.haarpsi, - 'piq.HaarPSI-loss': piq.HaarPSILoss(), 'piqa.HaarPSI': piqa.HaarPSI(), }), + 'VSI': (2, { + 'piq.vsi': piq.vsi, + 'piqa.VSI': piqa.VSI(), + }), + 'FSIM': (2, { + 'piq.fsim': piq.fsim, + 'piqa.FSIM': piqa.FSIM(), + }), } @@ -159,7 +159,6 @@ def main( y = totensor(truth).repeat(batch, 1, 1, 1).to(device) x.requires_grad_() - y.requires_grad_() # Metrics if metrics: @@ -170,6 +169,8 @@ def main( else: metrics = {k: v for (k, v) in METRICS.items()} + del metrics['LPIPS'] + # Benchmark for name, (nargs, methods) in metrics.items(): print(name) diff --git a/tests/doctests.py b/tests/doctests.py index d1e4371..6877439 100644 --- a/tests/doctests.py +++ b/tests/doctests.py @@ -17,6 +17,8 @@ mdsi, gmsd, haarpsi, + vsi, + fsim, ) @@ -34,6 +36,8 @@ mdsi, gmsd, haarpsi, + vsi, + fsim, ] for m in modules: