diff --git a/piqa/__init__.py b/piqa/__init__.py index b714aaa..a23caf9 100644 --- a/piqa/__init__.py +++ b/piqa/__init__.py @@ -5,4 +5,4 @@ specific image quality assessement metric. """ -__version__ = '1.0.2' +__version__ = '1.0.4' diff --git a/piqa/tv.py b/piqa/tv.py index 1e368ac..de2b991 100644 --- a/piqa/tv.py +++ b/piqa/tv.py @@ -21,14 +21,22 @@ def tv(x: torch.Tensor, norm: str = 'L2') -> torch.Tensor: `'L1'` | `'L2'` | `'L2_squared'`. """ - variation = torch.cat([ - x[..., :, 1:] - x[..., :, :-1], - x[..., 1:, :] - x[..., :-1, :], - ], dim=-2) + w_var = x[..., :, 1:] - x[..., :, :-1] + h_var = x[..., 1:, :] - x[..., :-1, :] - tv = tensor_norm(variation, dim=(-1, -2, -3), norm=norm) + if norm in ['L2', 'L2_squared']: + w_var = w_var ** 2 + h_var = h_var ** 2 + else: # norm == 'L1' + w_var = w_var.abs() + h_var = h_var.abs() - return tv + var = w_var.sum(dim=(-1, -2, -3)) + h_var.sum(dim=(-1, -2, -3)) + + if norm == 'L2': + var = torch.sqrt(var) + + return var class TV(nn.Module): diff --git a/piqa/utils.py b/piqa/utils.py index 8d8bb8c..a9782c6 100644 --- a/piqa/utils.py +++ b/piqa/utils.py @@ -153,19 +153,21 @@ def tensor_norm( def normalize_tensor( x: torch.Tensor, + dim: Tuple[int, ...] = (), + norm: str = 'L2', epsilon: float = 1e-8, - **kwargs, ) -> torch.Tensor: r"""Returns `x` normalized. Args: x: An input tensor. + dim: The dimension(s) along which to normalize. + norm: Specifies the norm funcion to use: + `'L1'` | `'L2'` | `'L2_squared'`. epsilon: A numerical stability term. - - `**kwargs` are transmitted to `tensor_norm`. """ - norm = tensor_norm(x, **kwargs) + norm = tensor_norm(x, dim=dim, keepdim=True, norm=norm) return x / (norm + epsilon)