Skip to content

Commit

Permalink
🚑 Fix FileNotFoundError for LPIPS weights
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Dec 10, 2020
1 parent 3860c5c commit 200460a
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 18 deletions.
2 changes: 1 addition & 1 deletion piqa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
specific image quality assessement metric.
"""

__version__ = '1.0.1'
__version__ = '1.0.2'
74 changes: 57 additions & 17 deletions piqa/lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,37 +16,80 @@
import torch
import torch.nn as nn
import torchvision.models as models
import torch.hub as hub

from piqa.utils import build_reduce, normalize_tensor, Intermediary

from typing import Dict

_SHIFT = torch.Tensor([0.485, 0.456, 0.406])
_SCALE = torch.Tensor([0.229, 0.224, 0.225])

_WEIGHTS_URL = (
'https://github.com/richzhang/PerceptualSimilarity'
'/raw/master/lpips/weights/{}/{}.pth'
)


def get_weights(
network: str = 'alex',
version: str = 'v0.1',
) -> Dict[str, torch.Tensor]:
r"""Returns the official LPIPS weights for `network`.
Args:
network: Specifies the perception network that is used:
`'alex'` | `'squeeze'` | `'vgg'`.
version: Specifies the official version release:
`'v0.0'` | `'v0.1'`.
"""

# Load from URL
weights = hub.load_state_dict_from_url(
_WEIGHTS_URL.format(version, network),
map_location='cpu',
)

# Format keys
weights = {
k.replace('lin', '').replace('.model', ''): v
for k, v in weights.items()
}

return weights


class LPIPS(nn.Module):
r"""Creates a criterion that measures the LPIPS
between an input and a target.
Args:
network: Specifies the perception network to use:
`'AlexNet'` | `'SqueezeNet'` | `'VGG16'`.
`'alex'` | `'squeeze'` | `'vgg'`.
scaling: Whether the input and target need to
be scaled w.r.t. ImageNet.
dropout: Whether dropout is used or not.
pretrained: Whether the official pretrained weights are used or not.
reduction: Specifies the reduction to apply to the output:
`'none'` | `'mean'` | `'sum'`.
Shape:
* Input: (N, 3, H, W)
* Target: (N, 3, H, W)
* Output: (N,) or (1,) depending on `reduction`
Note:
`LPIPS` is a *trainable* metric. To prevent the weights from updating,
use the `torch.no_grad()` context or freeze the weights.
"""

def __init__(
self,
network: str = 'AlexNet',
network: str = 'alex',
scaling: bool = True,
dropout: bool = True,
pretrained: bool = True,
reduction: str = 'mean',
trainable: bool = False,
):
r""""""
super().__init__()
Expand All @@ -57,39 +100,36 @@ def __init__(
self.register_buffer('scale', _SCALE.view(1, -1, 1, 1))

# Perception layers
if network == 'AlexNet':
if network == 'alex': # AlexNet
layers = models.alexnet(pretrained=True).features
targets = [1, 4, 7, 9, 11]
channels = [64, 192, 384, 256, 256]
elif network == 'SqueezeNet':
elif network == 'squeeze': # SqueezeNet
layers = models.squeezenet1_1(pretrained=True).features
targets = [1, 4, 7, 9, 10, 11, 12]
channels = [64, 128, 256, 384, 384, 512, 512]
elif network == 'VGG16':
elif network == 'vgg': # VGG16
layers = models.vgg16(pretrained=True).features
targets = [3, 8, 15, 22, 29]
channels = [64, 128, 256, 512, 512]
else:
raise ValueError('Unknown network architecture ' + network)

self.net = Intermediary(layers, targets)
for p in self.net.parameters():
p.requires_grad = False

# Linear comparators
state_path = os.path.join(
os.path.dirname(inspect.getsourcefile(self.__init__)),
f'weights/lpips_{network}.pth'
)

self.lin = nn.ModuleList([
nn.Conv2d(c, 1, kernel_size=1, stride=1, padding=0, bias=False)
nn.Sequential(
nn.Dropout(inplace=True) if dropout else nn.Identity(),
nn.Conv2d(c, 1, kernel_size=1, stride=1, padding=0, bias=False),
)
for c in channels
])
self.lin.load_state_dict(torch.load(state_path))

if not trainable:
# Prevent gradients
for p in self.parameters():
p.requires_grad = False
if pretrained:
self.lin.load_state_dict(get_weights(network=network))

self.reduce = build_reduce(reduction)

Expand Down
Binary file removed piqa/weights/lpips_AlexNet.pth
Binary file not shown.
Binary file removed piqa/weights/lpips_SqueezeNet.pth
Binary file not shown.
Binary file removed piqa/weights/lpips_VGG16.pth
Binary file not shown.

0 comments on commit 200460a

Please sign in to comment.