From 2be72f6bd848eada823ae6f210fea1f06967111d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Rozet?= Date: Fri, 25 Sep 2020 13:58:04 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Total=20Variation=20(TV)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- spiq/__init__.py | 1 + spiq/tv.py | 74 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+) create mode 100644 spiq/tv.py diff --git a/spiq/__init__.py b/spiq/__init__.py index 2723380..b049e6d 100644 --- a/spiq/__init__.py +++ b/spiq/__init__.py @@ -2,3 +2,4 @@ from .psnr import psnr, PSNR from .ssim import ssim, msssim, SSIM, MSSSIM +from .tv import tv, TV diff --git a/spiq/tv.py b/spiq/tv.py new file mode 100644 index 0000000..9a7e93a --- /dev/null +++ b/spiq/tv.py @@ -0,0 +1,74 @@ +r"""Total Variation (TV) + +This module implements the TV in PyTorch. + +Wikipedia: + https://en.wikipedia.org/wiki/Total_variation +""" + +########### +# Imports # +########### + +import torch +import torch.nn as nn + + +############# +# Functions # +############# + +def tv(x: torch.Tensor, norm='L1') -> torch.Tensor: + r"""Returns the TV of `x`. + + Args: + x: input tensor, (..., C, H, W) + norm: norm to use ('L1', 'L2' or 'L2_squared') + """ + + w_var = x[..., :, 1:] - x[..., :, :-1] + h_var = x[..., 1:, :] - x[..., :-1, :] + + if norm in ['L2', 'L2_squared']: + w_var = w_var ** 2 + h_var = h_var ** 2 + else: # norm == 'L1' + w_var = w_var.abs() + h_var = h_var.abs() + + score = w_var.sum(dim=(-1, -2, -3)) + h_var.sum(dim=(-1, -2, -3)) + + if norm == 'L2': + score = torch.sqrt(score) + + return score + + +########### +# Classes # +########### + +class TV(nn.Module): + r"""Creates a criterion that measures the TV of an input. + """ + + def __init__(self, norm='L1', reduction='mean'): + super().__init__() + + self.norm = norm + self.reduction = reduction + + def forward(self, input: torch.Tensor) -> torch.Tensor: + r""" + Args: + input: input tensor, (N, C, H, W) + """ + + l = tv(input, norm=self.norm) + + if self.reduction == 'mean': + return l.mean() + elif self.reduction == 'sum': + return l.sum() + + return l