Skip to content

Commit

Permalink
🚸 Provide option to disable debugging (#17)
Browse files Browse the repository at this point in the history
🚸 Improve type error messages
  • Loading branch information
francois-rozet committed Dec 24, 2021
1 parent d098f4d commit ffc07e0
Show file tree
Hide file tree
Showing 13 changed files with 68 additions and 56 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,10 @@ If you need the absolute best performances, the assertions can be disabled with
```bash
python -O your_awesome_code_using_piqa.py
```

Alternatively, you can disable PIQA's type assertions within your code with

```python
from piqa.utils import set_debug
set_debug(False)
```
6 changes: 3 additions & 3 deletions piqa/fsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import torch.nn as nn
import torch.nn.functional as F

from piqa.utils import _jit, _assert_type, _reduce
from piqa.utils import _jit, assert_type, reduce_tensor
from piqa.utils.color import ColorConv
from piqa.utils.functional import (
scharr_kernel,
Expand Down Expand Up @@ -308,7 +308,7 @@ def forward(
r"""Defines the computation performed at every call.
"""

_assert_type(
assert_type(
[input, target],
device=self.kernel.device,
dim_range=(4, 4),
Expand Down Expand Up @@ -339,4 +339,4 @@ def forward(
# FSIM
l = fsim(input, target, pc_input, pc_target, kernel=self.kernel, **self.kwargs)

return _reduce(l, self.reduction)
return reduce_tensor(l, self.reduction)
10 changes: 5 additions & 5 deletions piqa/gmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import torch.nn as nn
import torch.nn.functional as F

from piqa.utils import _jit, _assert_type, _reduce
from piqa.utils import _jit, assert_type, reduce_tensor
from piqa.utils.color import ColorConv
from piqa.utils.functional import (
prewitt_kernel,
Expand Down Expand Up @@ -219,7 +219,7 @@ def forward(
r"""Defines the computation performed at every call.
"""

_assert_type(
assert_type(
[input, target],
device=self.kernel.device,
dim_range=(4, 4),
Expand All @@ -239,7 +239,7 @@ def forward(
# GMSD
l = gmsd(input, target, kernel=self.kernel, **self.kwargs)

return _reduce(l, self.reduction)
return reduce_tensor(l, self.reduction)


class MS_GMSD(nn.Module):
Expand Down Expand Up @@ -310,7 +310,7 @@ def forward(
r"""Defines the computation performed at every call.
"""

_assert_type(
assert_type(
[input, target],
device=self.kernel.device,
dim_range=(4, 4),
Expand All @@ -331,4 +331,4 @@ def forward(
**self.kwargs,
)

return _reduce(l, self.reduction)
return reduce_tensor(l, self.reduction)
6 changes: 3 additions & 3 deletions piqa/haarpsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch.nn as nn
import torch.nn.functional as F

from piqa.utils import _jit, _assert_type, _reduce
from piqa.utils import _jit, assert_type, reduce_tensor
from piqa.utils.color import ColorConv
from piqa.utils.functional import (
haar_kernel,
Expand Down Expand Up @@ -171,7 +171,7 @@ def forward(
r"""Defines the computation performed at every call.
"""

_assert_type(
assert_type(
[input, target],
device=self.convert.device,
dim_range=(4, 4),
Expand All @@ -191,4 +191,4 @@ def forward(
# HaarPSI
l = haarpsi(input, target, **self.kwargs)

return _reduce(l, self.reduction)
return reduce_tensor(l, self.reduction)
6 changes: 3 additions & 3 deletions piqa/lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torchvision.models as models
import torch.hub as hub

from piqa.utils import _jit, _assert_type, _reduce
from piqa.utils import _jit, assert_type, reduce_tensor

from typing import Dict, List

Expand Down Expand Up @@ -207,7 +207,7 @@ def forward(
r"""Defines the computation performed at every call.
"""

_assert_type(
assert_type(
[input, target],
device=self.shift.device,
dim_range=(4, 4),
Expand All @@ -232,4 +232,4 @@ def forward(

l = torch.stack(residuals, dim=-1).sum(dim=-1)

return _reduce(l, self.reduction)
return reduce_tensor(l, self.reduction)
6 changes: 3 additions & 3 deletions piqa/mdsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import torch.nn as nn
import torch.nn.functional as F

from piqa.utils import _jit, _assert_type, _reduce
from piqa.utils import _jit, assert_type, reduce_tensor
from piqa.utils.color import ColorConv
from piqa.utils.functional import (
prewitt_kernel,
Expand Down Expand Up @@ -178,7 +178,7 @@ def forward(
r"""Defines the computation performed at every call.
"""

_assert_type(
assert_type(
[input, target],
device=self.kernel.device,
dim_range=(4, 4),
Expand All @@ -202,4 +202,4 @@ def forward(
# MDSI
l = mdsi(input, target, kernel=self.kernel, **self.kwargs)

return _reduce(l, self.reduction)
return reduce_tensor(l, self.reduction)
6 changes: 3 additions & 3 deletions piqa/psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
import torch.nn as nn

from piqa.utils import _jit, _assert_type, _reduce
from piqa.utils import _jit, assert_type, reduce_tensor


@_jit
Expand Down Expand Up @@ -109,7 +109,7 @@ def forward(
r"""Defines the computation performed at every call.
"""

_assert_type(
assert_type(
[input, target],
device=input.device,
dim_range=(1, -1),
Expand All @@ -118,4 +118,4 @@ def forward(

l = psnr(input, target, **self.kwargs)

return _reduce(l, self.reduction)
return reduce_tensor(l, self.reduction)
10 changes: 5 additions & 5 deletions piqa/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import torch.nn as nn
import torch.nn.functional as F

from piqa.utils import _jit, _assert_type, _reduce
from piqa.utils import _jit, assert_type, reduce_tensor
from piqa.utils.functional import (
gaussian_kernel,
kernel_views,
Expand Down Expand Up @@ -251,7 +251,7 @@ def forward(
r"""Defines the computation performed at every call.
"""

_assert_type(
assert_type(
[input, target],
device=self.kernel.device,
dim_range=(3, -1),
Expand All @@ -261,7 +261,7 @@ def forward(

l = ssim(input, target, kernel=self.kernel, **self.kwargs)[0]

return _reduce(l, self.reduction)
return reduce_tensor(l, self.reduction)


class MS_SSIM(nn.Module):
Expand Down Expand Up @@ -331,7 +331,7 @@ def forward(
r"""Defines the computation performed at every call.
"""

_assert_type(
assert_type(
[input, target],
device=self.kernel.device,
dim_range=(4, 4),
Expand All @@ -347,4 +347,4 @@ def forward(
**self.kwargs,
)

return _reduce(l, self.reduction)
return reduce_tensor(l, self.reduction)
6 changes: 3 additions & 3 deletions piqa/tv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
import torch.nn as nn

from piqa.utils import _jit, _assert_type, _reduce
from piqa.utils import _jit, assert_type, reduce_tensor


@_jit
Expand Down Expand Up @@ -94,8 +94,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
r"""Defines the computation performed at every call.
"""

_assert_type([input], device=input.device, dim_range=(3, -1))
assert_type([input], device=input.device, dim_range=(3, -1))

l = tv(input, **self.kwargs)

return _reduce(l, self.reduction)
return reduce_tensor(l, self.reduction)
43 changes: 27 additions & 16 deletions piqa/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,25 @@
_jit = lambda f: f


def _debug(mode: bool = __debug__) -> bool:
r"""Returns whether debugging is enabled or not.
__piqa_debug__ = __debug__

def set_debug(mode: bool = False) -> bool:
r"""Sets and returns whether debugging is enabled or not.
If `__debug__` is `False`, this function has not effect.
Example:
>>> set_debug(False)
False
"""

return mode
global __piqa_debug__

__piqa_debug__ = __debug__ and mode

return __piqa_debug__


def _assert_type(
def assert_type(
tensors: List[torch.Tensor],
device: torch.device,
dim_range: Tuple[int, int] = (0, -1),
Expand All @@ -33,60 +44,60 @@ def _assert_type(
Example:
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> _assert_type([x, y], device=x.device, dim_range=(4, 4), n_channels=3)
>>> assert_type([x, y], device=x.device, dim_range=(4, 4), n_channels=3)
"""

if not _debug():
if not __piqa_debug__:
return

ref = tensors[0]

for t in tensors:
assert t.device == device, (
f'Expected tensors to be on {device}, got {t.device}'
f'Tensors expected to be on {device}, got {t.device}'
)

assert t.shape == ref.shape, (
'Expected tensors to be of the same shape, got'
'Tensors expected to be of the same shape, got'
f' {ref.shape} and {t.shape}'
)

if dim_range[0] == dim_range[1]:
assert t.dim() == dim_range[0], (
'Expected number of dimensions to be'
'Number of dimensions expected to be'
f' {dim_range[0]}, got {t.dim()}'
)
elif dim_range[0] < dim_range[1]:
assert dim_range[0] <= t.dim() <= dim_range[1], (
'Expected number of dimensions to be between'
'Number of dimensions expected to be between'
f' {dim_range[0]} and {dim_range[1]}, got {t.dim()}'
)
elif dim_range[0] > 0:
assert dim_range[0] <= t.dim(), (
'Expected number of dimensions to be greater or equal to'
'Number of dimensions expected to be greater or equal to'
f' {dim_range[0]}, got {t.dim()}'
)

if n_channels > 0:
assert t.size(1) == n_channels, (
'Expected number of channels to be'
'Number of channels expected to be'
f' {n_channels}, got {t.size(1)}'
)

if value_range[0] < value_range[1]:
assert value_range[0] <= t.min(), (
'Expected values to be greater or equal to'
'Values expected to be greater or equal to'
f' {value_range[0]}, got {t.min()}'
)

assert t.max() <= value_range[1], (
'Expected values to be lower or equal to'
'Values expected to be lower or equal to'
f' {value_range[1]}, got {t.max()}'
)


@_jit
def _reduce(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
def reduce_tensor(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
r"""Returns the reduction of \(x\).
Args:
Expand All @@ -96,7 +107,7 @@ def _reduce(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
Example:
>>> x = torch.arange(5)
>>> _reduce(x, reduction='sum')
>>> reduce_tensor(x, reduction='sum')
tensor(10)
"""

Expand Down
3 changes: 3 additions & 0 deletions piqa/utils/color.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ def device(self) -> torch.device:
return self.weight.device

def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""Defines the computation performed at every call.
"""

return color_conv(x, self.weight)


Expand Down
6 changes: 3 additions & 3 deletions piqa/vsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import torch.nn as nn
import torch.nn.functional as F

from piqa.utils import _jit, _assert_type, _reduce
from piqa.utils import _jit, assert_type, reduce_tensor
from piqa.utils.color import ColorConv, rgb_to_xyz, xyz_to_lab
from piqa.utils.functional import (
scharr_kernel,
Expand Down Expand Up @@ -261,7 +261,7 @@ def forward(
r"""Defines the computation performed at every call.
"""

_assert_type(
assert_type(
[input, target],
device=self.kernel.device,
dim_range=(4, 4),
Expand Down Expand Up @@ -292,4 +292,4 @@ def forward(
# VSI
l = vsi(input, target, vs_input, vs_target, kernel=self.kernel, **self.kwargs)

return _reduce(l, self.reduction)
return reduce_tensor(l, self.reduction)
Loading

0 comments on commit ffc07e0

Please sign in to comment.