Skip to content

Commit

Permalink
✨ Structural Similarity (SSIM)
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Oct 23, 2020
1 parent d79e5f3 commit c1addd2
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 1 deletion.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,6 @@ import spiq
x = torch.rand(3, 3, 256, 256)
y = torch.rand(3, 3, 256, 256)

- l = spiq.psnr(x, y)
+ a = spiq.psnr(x, y)
+ b = spiq.ssim(x, y)
```
1 change: 1 addition & 0 deletions spiq/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
__version__ = '0.0.1'

from .psnr import psnr, PSNR
from .ssim import ssim, SSIM
176 changes: 176 additions & 0 deletions spiq/ssim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
r"""Structural Similarity (SSIM)
This module implements the SSIM in PyTorch.
Wikipedia:
https://en.wikipedia.org/wiki/Structural_similarity
Credits:
Inspired by pytorch-msssim
https://github.com/VainF/pytorch-msssim
References:
[1] Multiscale structural similarity for image quality assessment
(Wang et al., 2003)
https://ieeexplore.ieee.org/abstract/document/1292216/
[2] Image quality assessment: From error visibility to structural similarity
(Wang et al., 2004)
https://ieeexplore.ieee.org/abstract/document/1284395/
"""

###########
# Imports #
###########

from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


#############
# Constants #
#############

# Standard deviation handed to `gaussian_kernel` by `create_window`
# when building the SSIM convolution window.
_SIGMA = 1.5
# Factors scaled by `value_range` and squared in `ssim_per_channel` to obtain
# c1 and c2, the constants added to the numerators and denominators of the
# SSIM and CS ratios.
_K1, _K2 = 0.01, 0.03


#############
# Functions #
#############

def gaussian_kernel(kernel_size: int, sigma: float=1., n: int=2) -> torch.Tensor:
    r"""Builds a normalized `n`-dimensional Gaussian kernel.

    Each axis holds `kernel_size` samples. The Gaussian is centered on the
    middle of the kernel, has standard deviation `sigma`, and its weights
    are scaled so that they sum to one.

    Args:
        kernel_size: number of samples along each axis
        sigma: standard deviation of the Gaussian
        n: number of axes of the returned kernel

    Wikipedia:
        https://en.wikipedia.org/wiki/Normal_distribution
    """

    # Squared distance of every tap to the kernel's center, along one axis
    offsets = torch.arange(kernel_size).float() - (kernel_size - 1) / 2
    sq_dist = offsets ** 2

    # Add one axis at a time: broadcasting a trailing singleton axis against
    # the 1-d squared distances yields the squared Euclidean norm
    sq_norm = sq_dist
    for _ in range(n - 1):
        sq_norm = sq_norm.unsqueeze(-1) + sq_dist

    two_variance = 2 * sigma ** 2
    weights = torch.exp(-sq_norm / two_variance)

    return weights / weights.sum()


def create_window(window_size: int, n_channels: int) -> torch.Tensor:
    r"""Returns the SSIM convolution window of size `window_size`.

    The window is a 2-d Gaussian kernel (standard deviation `_SIGMA`)
    replicated once per channel, laid out as the weight of a grouped
    (depthwise) 2-d convolution.

    Args:
        window_size: size of the window
        n_channels: number of channels

    Returns:
        The window, (n_channels, 1, window_size, window_size)
    """

    kernel = gaussian_kernel(window_size, _SIGMA)

    window = kernel.unsqueeze(0).unsqueeze(0)

    # `repeat` (unlike `expand`) allocates real storage for every channel.
    # An expanded view has stride 0 on the channel axis, so all channels
    # alias the same memory and in-place writes to it (e.g. the `copy_`
    # performed when loading a state dict into a registered buffer) are
    # rejected or unsafe.
    window = window.repeat(n_channels, 1, 1, 1)

    return window


def ssim_per_channel(x: torch.Tensor, y: torch.Tensor, window: torch.Tensor, value_range: float=1.) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Returns the SSIM and the contrast sensitivity (CS) per channel between `x` and `y`.

    Local statistics are obtained by filtering with `window` through a
    grouped convolution (one filter per channel), padded by half the
    window size.

    Args:
        x: input tensor, (N, C, H, W)
        y: target tensor, (N, C, H, W)
        window: convolution window, (C, 1, window_size, window_size)
        value_range: value range of the inputs (usually 1. or 255)

    Returns:
        A pair `(ssim, cs)`: the SSIM map and the CS map, each averaged
        over the two spatial dimensions.
    """

    n_channels, _, window_size, _ = window.size()
    padding = window_size // 2

    def local_mean(t: torch.Tensor) -> torch.Tensor:
        # Window-weighted local average, computed independently per channel
        return F.conv2d(t, window, padding=padding, groups=n_channels)

    mu_x = local_mean(x)
    mu_y = local_mean(y)

    mu_x_sq = mu_x ** 2
    mu_y_sq = mu_y ** 2
    mu_xy = mu_x * mu_y

    # Var[x] = E[x^2] - E[x]^2 and Cov[x, y] = E[xy] - E[x]E[y]
    sigma_x_sq = local_mean(x ** 2) - mu_x_sq
    sigma_y_sq = local_mean(y ** 2) - mu_y_sq
    sigma_xy = local_mean(x * y) - mu_xy

    # Constants keeping the ratios well-defined when the denominators vanish
    c1 = (_K1 * value_range) ** 2
    c2 = (_K2 * value_range) ** 2

    cs_map = (2. * sigma_xy + c2) / (sigma_x_sq + sigma_y_sq + c2)
    ssim_map = (2. * mu_xy + c1) / (mu_x_sq + mu_y_sq + c1) * cs_map

    return ssim_map.mean((-1, -2)), cs_map.mean((-1, -2))


def ssim(x: torch.Tensor, y: torch.Tensor, window_size: int=11, value_range: float=1.) -> torch.Tensor:
    r"""Returns the SSIM between `x` and `y`.

    The per-channel SSIM values are averaged over the channel dimension,
    giving one value per element of the batch.

    Args:
        x: input tensor of shape, (N, C, H, W)
        y: target tensor of shape, (N, C, H, W)
        window_size: size of the window
        value_range: value range of the inputs (usually 1. or 255)
    """

    n_channels = x.size(1)
    window = create_window(window_size, n_channels)

    # The window is created as float32 on the CPU. Move it next to the input
    # and, for floating point inputs, match their precision as well: the
    # convolution requires input and weight to share one dtype, so float64
    # or half inputs would otherwise fail.
    window = window.to(
        device=x.device,
        dtype=x.dtype if x.is_floating_point() else window.dtype,
    )

    return ssim_per_channel(x, y, window, value_range)[0].mean(-1)


###########
# Classes #
###########

class SSIM(nn.Module):
    r"""Criterion measuring the SSIM between an input and a target.

    The convolution window is built once at construction and stored as the
    buffer `window`. The per-sample SSIM values are reduced according to
    `reduction`: `'mean'` averages them, `'sum'` adds them up and any other
    value leaves them unreduced.
    """

    def __init__(self, window_size: int=11, n_channels: int=3, value_range: float=1., reduction='mean'):
        super().__init__()

        self.value_range = value_range
        self.reduction = reduction

        window = create_window(window_size, n_channels)
        self.register_buffer('window', window)

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            input: input tensor, (N, C, H, W)
            target: target tensor, (N, C, H, W)
        """

        per_channel, _ = ssim_per_channel(
            input,
            target,
            window=self.window,
            value_range=self.value_range,
        )

        # Average over channels: one SSIM value per sample
        scores = per_channel.mean(-1)

        if self.reduction == 'sum':
            return scores.sum()

        if self.reduction == 'mean':
            return scores.mean()

        return scores

0 comments on commit c1addd2

Please sign in to comment.