diff --git a/piqa/__init__.py b/piqa/__init__.py
index ebe79ba..2a56455 100644
--- a/piqa/__init__.py
+++ b/piqa/__init__.py
@@ -5,7 +5,7 @@
 specific image quality assessement metric.
 """
 
-__version__ = '1.1.2'
+__version__ = '1.1.3'
 
 from .tv import TV
 from .psnr import PSNR
diff --git a/piqa/gmsd.py b/piqa/gmsd.py
index 66236c4..ff98733 100644
--- a/piqa/gmsd.py
+++ b/piqa/gmsd.py
@@ -25,7 +25,6 @@
     prewitt_kernel,
     gradient_kernel,
     channel_conv,
-    tensor_norm,
 )
 
 
@@ -77,8 +76,8 @@ def gmsd(
     # Gradient magnitude
     pad = kernel.size(-1) // 2
 
-    gm_x = tensor_norm(channel_conv(x, kernel, padding=pad), dim=[1])
-    gm_y = tensor_norm(channel_conv(y, kernel, padding=pad), dim=[1])
+    gm_x = torch.linalg.norm(channel_conv(x, kernel, padding=pad), dim=1)
+    gm_y = torch.linalg.norm(channel_conv(y, kernel, padding=pad), dim=1)
 
     gm_xy = gm_x * gm_y
 
diff --git a/piqa/lpips.py b/piqa/lpips.py
index d1239d3..4225735 100644
--- a/piqa/lpips.py
+++ b/piqa/lpips.py
@@ -19,7 +19,6 @@
 import torch.hub as hub
 
 from piqa.utils import _jit, _assert_type, _reduce
-from piqa.utils.functional import normalize_tensor
 
 from typing import Dict, List
 
@@ -225,8 +224,8 @@ def forward(
         residuals = []
 
         for lin, fx, fy in zip(self.lins, self.net(input), self.net(target)):
-            fx = normalize_tensor(fx, dim=[1], norm='L2')
-            fy = normalize_tensor(fy, dim=[1], norm='L2')
+            fx = fx / torch.linalg.norm(fx, dim=1, keepdim=True)
+            fy = fy / torch.linalg.norm(fy, dim=1, keepdim=True)
 
             mse = ((fx - fy) ** 2).mean(dim=(-1, -2), keepdim=True)
             residuals.append(lin(mse).flatten())
diff --git a/piqa/mdsi.py b/piqa/mdsi.py
index 3a19cdb..21c72c9 100644
--- a/piqa/mdsi.py
+++ b/piqa/mdsi.py
@@ -19,7 +19,6 @@
     prewitt_kernel,
     gradient_kernel,
     channel_conv,
-    tensor_norm,
 )
 
 import piqa.utils.complex as cx
@@ -77,11 +76,11 @@ def mdsi(
     # Gradient magnitude
     pad = kernel.size(-1) // 2
 
-    gm_x = tensor_norm(channel_conv(l_x, kernel, padding=pad), dim=[1])
-    gm_y = tensor_norm(channel_conv(l_y, kernel, padding=pad), dim=[1])
-    gm_avg = tensor_norm(
+    gm_x = torch.linalg.norm(channel_conv(l_x, kernel, padding=pad), dim=1)
+    gm_y = torch.linalg.norm(channel_conv(l_y, kernel, padding=pad), dim=1)
+    gm_avg = torch.linalg.norm(
         channel_conv((l_x + l_y) / 2., kernel, padding=pad),
-        dim=[1],
+        dim=1,
     )
 
     gm_x_sq, gm_y_sq, gm_avg_sq = gm_x ** 2, gm_y ** 2, gm_avg ** 2
diff --git a/piqa/utils/__init__.py b/piqa/utils/__init__.py
index 7e8f732..c84f458 100644
--- a/piqa/utils/__init__.py
+++ b/piqa/utils/__init__.py
@@ -85,10 +85,12 @@ def _assert_type(
     )
 
 
+@_jit
 def _reduce(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
-    r"""Returns a reducing module.
+    r"""Returns the reduction of \(x\).
 
     Args:
+        x: A tensor, \((*,)\).
         reduction: Specifies the reduction type:
             `'none'` | `'mean'` | `'sum'`.
 
diff --git a/piqa/utils/complex.py b/piqa/utils/complex.py
index e001a6e..b90e847 100644
--- a/piqa/utils/complex.py
+++ b/piqa/utils/complex.py
@@ -70,7 +70,7 @@ def mod(x: torch.Tensor, squared: bool = False) -> torch.Tensor:
         tensor([2.0000, 1.0000])
     """
 
-    x = (x ** 2).sum(dim=-1)
+    x = x.square().sum(dim=-1)
 
     if not squared:
         x = torch.sqrt(x)
diff --git a/piqa/utils/functional.py b/piqa/utils/functional.py
index e58e63b..8ab941c 100644
--- a/piqa/utils/functional.py
+++ b/piqa/utils/functional.py
@@ -82,10 +82,13 @@ def gaussian_kernel(
 ) -> torch.Tensor:
     r"""Returns the 1-dimensional Gaussian kernel of size \(K\).
 
-    $$ G(x) = \frac{1}{\sum_{y = 1}^{K} G(y)} \exp
-        \left(\frac{(x - \mu)^2}{2 \sigma^2}\right) $$
+    $$ G(x) = \gamma \exp
+        \left(-\frac{(x - \mu)^2}{2 \sigma^2}\right) $$
 
-    where \(x \in [1; K]\) is a position in the kernel
+    where \(\gamma\) is such that
+
+    $$ \sum_{x = 1}^{K} G(x) = 1 $$
+
     and \(\mu = \frac{1 + K}{2}\).
 
     Args:
@@ -263,124 +266,3 @@ def gradient_kernel(kernel: torch.Tensor) -> torch.Tensor:
     """
 
     return torch.stack([kernel, kernel.t()]).unsqueeze(1)
-
-
-def tensor_norm(
-    x: torch.Tensor,
-    dim: List[int],  # Union[int, Tuple[int, ...]] = ()
-    keepdim: bool = False,
-    norm: str = 'L2',
-) -> torch.Tensor:
-    r"""Returns the norm of \(x\).
-
-    $$ L_1(x) = \left\| x \right\|_1 = \sum_i \left| x_i \right| $$
-
-    $$ L_2(x) = \left\| x \right\|_2 = \left( \sum_i x^2_i \right)^\frac{1}{2} $$
-
-    Args:
-        x: A tensor, \((*,)\).
-        dim: The dimension(s) along which to calculate the norm.
-        keepdim: Whether the output tensor has `dim` retained or not.
-        norm: Specifies the norm funcion to apply:
-            `'L1'` | `'L2'` | `'L2_squared'`.
-
-    Wikipedia:
-        https://en.wikipedia.org/wiki/Norm_(mathematics)
-
-    Example:
-        >>> x = torch.arange(9).float().view(3, 3)
-        >>> x
-        tensor([[0., 1., 2.],
-                [3., 4., 5.],
-                [6., 7., 8.]])
-        >>> tensor_norm(x, dim=0)
-        tensor([6.7082, 8.1240, 9.6437])
-    """
-
-    if norm == 'L1':
-        x = x.abs()
-    else:  # norm in ['L2', 'L2_squared']
-        x = x ** 2
-
-    x = x.sum(dim=dim, keepdim=keepdim)
-
-    if norm == 'L2':
-        x = x.sqrt()
-
-    return x
-
-
-def normalize_tensor(
-    x: torch.Tensor,
-    dim: List[int],  # Union[int, Tuple[int, ...]] = ()
-    norm: str = 'L2',
-    epsilon: float = 1e-8,
-) -> torch.Tensor:
-    r"""Returns \(x\) normalized.
-
-    $$ \hat{x} = \frac{x}{\left\|x\right\|} $$
-
-    Args:
-        x: A tensor, \((*,)\).
-        dim: The dimension(s) along which to normalize.
-        norm: Specifies the norm funcion to use:
-            `'L1'` | `'L2'` | `'L2_squared'`.
-        epsilon: A numerical stability term.
-
-    Returns:
-        The normalized tensor, \((*,)\).
-
-    Example:
-        >>> x = torch.arange(9, dtype=torch.float).view(3, 3)
-        >>> x
-        tensor([[0., 1., 2.],
-                [3., 4., 5.],
-                [6., 7., 8.]])
-        >>> normalize_tensor(x, dim=0)
-        tensor([[0.0000, 0.1231, 0.2074],
-                [0.4472, 0.4924, 0.5185],
-                [0.8944, 0.8616, 0.8296]])
-    """
-
-    norm = tensor_norm(x, dim=dim, keepdim=True, norm=norm)
-
-    return x / (norm + epsilon)
-
-
-def unravel_index(
-    indices: torch.LongTensor,
-    shape: List[int],
-) -> torch.LongTensor:
-    r"""Converts flat indices into unraveled coordinates in a target shape.
-
-    This is a `torch` implementation of `numpy.unravel_index`.
-
-    Args:
-        indices: A tensor of (flat) indices, \((*, N)\).
-        shape: The targeted shape, \((D,)\).
-
-    Returns:
-        The unraveled coordinates, \((*, N, D)\).
-
-    Example:
-        >>> unravel_index(torch.arange(9), shape=(3, 3))
-        tensor([[0, 0],
-                [0, 1],
-                [0, 2],
-                [1, 0],
-                [1, 1],
-                [1, 2],
-                [2, 0],
-                [2, 1],
-                [2, 2]])
-    """
-
-    coord = []
-
-    for dim in reversed(shape):
-        coord.append(indices % dim)
-        indices = indices // dim
-
-    coord = torch.stack(coord[::-1], dim=-1)
-
-    return coord
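
For reference, a quick sanity check of the substitutions above. This is a sketch, not part of the patch: the `tensor_norm` body is copied from the removed helper, and the test shapes, `sigma`, and tolerances are arbitrary.

```python
import torch

# Copy of the removed tensor_norm helper, for comparison only.
def tensor_norm(x, dim, keepdim=False, norm='L2'):
    if norm == 'L1':
        x = x.abs()
    else:  # norm in ['L2', 'L2_squared']
        x = x ** 2
    x = x.sum(dim=dim, keepdim=keepdim)
    if norm == 'L2':
        x = x.sqrt()
    return x

x = torch.randn(4, 3, 32, 32)

# GMSD/MDSI: L2 norm over the channel dimension.
assert torch.allclose(
    tensor_norm(x, dim=[1]),
    torch.linalg.norm(x, dim=1),
)

# LPIPS: unit-normalization of features. The new expression matches
# normalize_tensor up to the epsilon (1e-8) the old helper added to the
# denominator, so the two differ only for (nearly) zero feature vectors.
f = torch.randn(4, 64, 16, 16)
assert torch.allclose(
    f / (tensor_norm(f, dim=[1], keepdim=True) + 1e-8),
    f / torch.linalg.norm(f, dim=1, keepdim=True),
    atol=1e-5,
)

# gaussian_kernel docstring: gamma is the constant that makes the kernel
# sum to one, with mu = (1 + K) / 2 and x in [1, K].
K, sigma = 5, 1.5
xs = torch.arange(1, K + 1, dtype=torch.float)
g = torch.exp(-(xs - (1 + K) / 2) ** 2 / (2 * sigma ** 2))
g = g / g.sum()  # gamma = 1 / sum of the unnormalized weights
assert torch.isclose(g.sum(), torch.tensor(1.))
```

Note that the new LPIPS expression drops the epsilon guard, which is only observable when an entire feature vector is zero.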