-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcommons.py
72 lines (57 loc) · 1.96 KB
/
commons.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import io
import PIL
import torch
import torch.nn as nn
from torchvision import models
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
def get_model(checkpoint_path='classifier_resnet152.pth', num_classes=38):
    """Build the ResNet-152 classifier and load its fine-tuned weights.

    A torchvision ResNet-152 is created with pretrained backbone weights, its
    final ``fc`` layer is replaced by a small MLP head
    (in_features -> 512 -> ReLU -> Dropout(0.2) -> ``num_classes``), and the
    ``'state_dict'`` entry of the checkpoint is loaded onto the CPU.

    Args:
        checkpoint_path: Path to a ``torch.save``-d dict that has a
            ``'state_dict'`` key. Defaults to the previously hard-coded file.
        num_classes: Output size of the final linear layer; must match the
            checkpoint. Defaults to the previously hard-coded 38.

    Returns:
        The model switched to evaluation mode (dropout disabled).

    Warns:
        UserWarning: if the checkpoint's keys do not fully match the model.
            Loading is still non-strict, as before, so this never raises.
    """
    import warnings

    # NOTE(review): `pretrained=True` is deprecated in newer torchvision in
    # favour of `weights=...`; kept for compatibility with the installed version.
    model = models.resnet152(pretrained=True)
    # Read the head's input width from the stock layer instead of hard-coding it.
    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Linear(in_features, 512),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(512, num_classes),
    )

    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # strict=False is kept for backward compatibility, but a mismatched
    # checkpoint (e.g. keys prefixed with "module.") would otherwise leave
    # layers silently untrained, so report what was skipped.
    incompatible = model.load_state_dict(checkpoint['state_dict'], strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"Checkpoint {checkpoint_path!r} did not fully match the model: "
            f"missing keys {incompatible.missing_keys}, "
            f"unexpected keys {incompatible.unexpected_keys}"
        )

    model.eval()
    return model
def process_image(image, resize_size=256, crop_size=224):
    """Scale, center-crop and normalize a PIL image for a PyTorch model.

    The image is converted to RGB and resized to ``resize_size`` x
    ``resize_size``; the aspect ratio is not preserved. It is then
    center-cropped to ``crop_size`` x ``crop_size``, scaled to [0, 1], and
    normalized per channel with the mean/std constants below (the standard
    ImageNet statistics).

    Args:
        image: A PIL image.
        resize_size: Side length of the square resize. Defaults to 256.
        crop_size: Side length of the square center crop. Defaults to 224.

    Returns:
        A float64 NumPy array of shape (3, crop_size, crop_size), channel-first.
    """
    # resize() returns a NEW image; the previous code discarded this result,
    # so the crop was taken from the unscaled image. LANCZOS is the filter the
    # old Image.ANTIALIAS alias (removed in Pillow 10) referred to.
    # convert('RGB') guarantees exactly three channels (handles L / RGBA / P).
    image = image.convert('RGB').resize((resize_size, resize_size), Image.LANCZOS)

    # Center crop: for the defaults this is the box (16, 16, 240, 240).
    offset = (resize_size - crop_size) // 2
    image = image.crop((offset, offset, offset + crop_size, offset + crop_size))

    np_image = np.array(image) / 255.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    # Broadcasts over the last (channel) axis; same per-element math as
    # normalizing each channel separately.
    np_image = (np_image - mean) / std

    # HWC (PIL/NumPy layout) -> CHW (PyTorch layout).
    return np.transpose(np_image, (2, 0, 1))
def imshow(image, ax=None, title=None):
    """Display a normalized, channel-first image on a matplotlib axis.

    Reverses the per-channel normalization (mean/std below), clips the result
    to [0, 1] and draws it.

    Args:
        image: A (3, H, W) torch tensor or NumPy array normalized with the
            mean/std constants below. Tensors may require grad or live on a GPU.
        ax: Axis to draw on; a new figure and axis are created when None.
        title: Optional title set on the axis.

    Returns:
        The matplotlib axis the image was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()

    if isinstance(image, torch.Tensor):
        # .numpy() alone raises for tensors that require grad or are on a GPU.
        image = image.detach().cpu().numpy()
    else:
        image = np.asarray(image)

    # PyTorch puts the colour channel first; matplotlib expects it last.
    image = image.transpose((1, 2, 0))

    # Undo preprocessing.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = std * image + mean

    # Image needs to be clipped between 0 and 1 or it looks like noise when displayed.
    image = np.clip(image, 0, 1)

    ax.imshow(image)
    if title is not None:
        ax.set_title(title)
    return ax