-
-
Notifications
You must be signed in to change notification settings - Fork 223
/
Copy pathprocess_image.py
76 lines (56 loc) · 2.01 KB
/
process_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#!/usr/bin/env python
# coding: utf-8
import os,sys
sys.path.insert(0,"..")
from glob import glob
import matplotlib.pyplot as plt
import numpy as np
import argparse
import skimage, skimage.io
import pprint
import torch
import torch.nn.functional as F
import torchvision, torchvision.transforms
import torchxrayvision as xrv
# Script body: load one image, run a torchxrayvision model on it and print
# the per-pathology predictions (optionally with the pooled feature vector).

# ---- Command-line interface -------------------------------------------------
parser = argparse.ArgumentParser()
# NOTE(review): '-f' is accepted but never read below.
parser.add_argument('-f', type=str, default="", help='')
parser.add_argument('img_path', type=str)
parser.add_argument('-weights', type=str, default="densenet121-res224-all")
parser.add_argument('-feats', default=False, help='', action='store_true')
parser.add_argument('-cuda', default=False, help='', action='store_true')
parser.add_argument('-resize', default=False, help='', action='store_true')

cfg = parser.parse_args()

# ---- Image loading and preparation ------------------------------------------
img = skimage.io.imread(cfg.img_path)
# 255 is passed as the maximum input value.
# NOTE(review): this presumably assumes an 8-bit image -- confirm for other
# bit depths.
img = xrv.datasets.normalize(img, 255)

# The rest of the pipeline needs a 2D array.  Abort with a non-zero exit
# status if that is impossible; previously the script only printed the message
# and then crashed with an unrelated IndexError on the indexing below.
if img.ndim < 2:
    sys.exit("error, dimension lower than 2 for image")
if img.ndim > 2:
    # Keep only the first channel of a multi-channel image.
    img = img[:, :, 0]

# Add color channel
img = img[None, :, :]

# the models will resize the input to the correct size so this is optional.
steps = [xrv.datasets.XRayCenterCrop()]
if cfg.resize:
    steps.append(xrv.datasets.XRayResizer(224))
transform = torchvision.transforms.Compose(steps)

img = transform(img)

# ---- Inference --------------------------------------------------------------
model = xrv.models.get_model(cfg.weights)
# Explicitly select inference behaviour for dropout / batch-norm layers;
# torch.no_grad() alone does not do this.  No-op if the model is already in
# evaluation mode.
model.eval()

output = {}
with torch.no_grad():
    # Add the leading batch dimension.
    img = torch.from_numpy(img).unsqueeze(0)
    if cfg.cuda:
        img = img.cuda()
        model = model.cuda()

    if cfg.feats:
        # Feature extractor output -> ReLU -> global average pool, flattened
        # into a plain list of values.
        feats = model.features(img)
        feats = F.relu(feats, inplace=True)
        feats = F.adaptive_avg_pool2d(feats, (1, 1))
        output["feats"] = list(feats.cpu().detach().numpy().reshape(-1))

    preds = model(img).cpu()
    # NOTE(review): assumes the model's outputs are ordered like
    # xrv.datasets.default_pathologies -- confirm for non-default weights.
    output["preds"] = dict(zip(xrv.datasets.default_pathologies,
                               preds[0].detach().numpy()))

# ---- Reporting --------------------------------------------------------------
# With features requested the dict is printed raw; otherwise it is
# pretty-printed.
if cfg.feats:
    print(output)
else:
    pprint.pprint(output)