forked from milesial/Pytorch-UNet
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathto-ts.py
executable file
·57 lines (45 loc) · 1.61 KB
/
to-ts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#!/usr/bin/env python
import argparse
import os
import numpy as np
import torch
import torchvision
from unet import UNet
from uresnet import UResNet
from nestedunet import NestedUNet
def get_args(argv=None):
    """Parse the command-line options for the TorchScript export script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read the real command line, so existing
            ``get_args()`` calls behave exactly as before.

    Returns:
        argparse.Namespace with:
            model: path of the saved model file (default ``'MODEL.pth'``).
            gpu: ``True`` when ``--gpu`` / ``-g`` was passed, else ``False``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', '-m', default='MODEL.pth',
                        metavar='FILE',
                        help="Specify the file in which is stored the model"
                             " (default : 'MODEL.pth')")
    # action='store_true' already defaults to False; no explicit default needed.
    parser.add_argument('--gpu', '-g', action='store_true',
                        help="Use cuda version of the net")
    return parser.parse_args(argv)
def count_params(net):
    """Print and return the number of trainable parameters of ``net``.

    Only parameters with ``requires_grad`` set are counted, so frozen
    weights are excluded.

    Args:
        net: a ``torch.nn.Module`` (anything with a ``parameters()`` method
            yielding tensors).

    Returns:
        int: total number of trainable scalar elements.
    """
    # Tensor.numel() is an exact Python int. np.prod(p.size()) yields the
    # float 1.0 for 0-d (scalar) parameters, which made the whole sum a float.
    params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print('params = ', params)
    return params
if __name__ == "__main__":
    args = get_args()

    input_channels = 3
    output_channels = 1

    # Choose the architecture to export; the commented alternatives are
    # constructed the same way.
    net = UNet(input_channels, output_channels)
    # net = UResNet(input_channels, output_channels)
    # net = NestedUNet(input_channels, output_channels)
    count_params(net)

    device = torch.device('cuda' if args.gpu else 'cpu')
    # map_location places the loaded tensors on the chosen device, so a
    # checkpoint saved on GPU can still be exported on a CPU-only machine.
    net.load_state_dict(torch.load(args.model, map_location=device))
    net.to(device)
    # Trace in inference mode. In training mode, layers whose behaviour
    # depends on train/eval (e.g. BatchNorm, Dropout, if the architecture
    # has any) would be recorded with their training behaviour baked into
    # the exported graph.
    net.eval()

    # Dummy input used only to record the graph: batch of 1,
    # input_channels channels, 800x600 spatial size.
    example = torch.rand(1, input_channels, 800, 600).to(device)
    sm = torch.jit.trace(net, example)
    sm.save('ts-model.ts')