-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdemo_torchscript.py
97 lines (73 loc) · 2.67 KB
/
demo_torchscript.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from collections import deque

import cv2
import numpy as np
import torch
from ruamel.yaml import YAML
from torch.utils.data import DataLoader

from dataloaders import KITTIRawLoader as KRL
# Script-wide knobs: which stereo config to load and which recording to play.
config = "cfg_coex.yaml"   # file name looked up under ./configs/stereo/
vid_date = "2011_09_26"    # presumably the KITTI raw recording date
vid_num = "0093"           # presumably the drive number within that date
half_precision = True      # NOTE(review): not read anywhere in this script

# Inference-only process: let cuDNN autotune its kernels and switch autograd
# off globally so no computation graph is ever recorded.
torch.backends.cudnn.benchmark = True
torch.set_grad_enabled(False)
def load_configs(path):
    """Load a stereo YAML config and merge its backbone config into it.

    The file at *path* must contain ``model.stereo.backbone.cfg_path``. The
    YAML file named by that key is loaded, and its top-level keys are merged
    into ``cfg['model']['stereo']['backbone']``, overwriting duplicates.

    Args:
        path: path to the main stereo config file.

    Returns:
        The merged config, as produced by ruamel.yaml's default ``YAML()``
        loader.

    Raises:
        KeyError: if ``model``/``stereo``/``backbone``/``cfg_path`` is missing.
        OSError: if either YAML file cannot be opened.
    """
    # Context managers close both handles promptly. The previous version
    # left them open until garbage collection.
    with open(path, 'r') as f:
        cfg = YAML().load(f)
    backbone = cfg['model']['stereo']['backbone']
    # cfg_path is opened as given. Relative paths resolve against the
    # current working directory, not the directory of *path*.
    with open(backbone['cfg_path'], 'r') as f:
        backbone.update(YAML().load(f))
    return cfg
def disparity_to_uint8(disp, scale=2.0):
    """Quantise the first disparity map of a batch to an 8-bit image.

    The map is multiplied by *scale* (2 by default, matching the original
    visualisation). It is then clamped to [0, 255] before the uint8 cast, so
    large disparities saturate at 255 instead of wrapping around to dark
    values.

    Args:
        disp: tensor whose first index selects a batch item. Presumably
            (B, H, W) disparity in pixels -- TODO confirm against the model.
        scale: multiplier applied before quantisation.

    Returns:
        numpy uint8 array with the shape of ``disp[0]``.
    """
    scaled = (scale * disp[0]).clamp(0, 255)
    return scaled.detach().cpu().numpy().astype(np.uint8)


if __name__ == '__main__':
    cfg = load_configs('./configs/stereo/{}'.format(config))
    stereo = torch.jit.load('zoo/torchscript/CoEx.pt')

    left_cam, right_cam = KRL.listfiles(cfg, vid_date, vid_num, True)

    # NOTE(review): a zero th/tw presumably disables cropping in the loader's
    # demo mode -- confirm in KRL.ImageLoader.
    cfg['training']['th'] = 0
    cfg['training']['tw'] = 0
    kitti_train = KRL.ImageLoader(
        left_cam, right_cam, cfg, training=True, demo=True)
    kitti_train = DataLoader(
        kitti_train, batch_size=1,
        num_workers=4, shuffle=False, drop_last=False)

    # FPS of the five most recent frames. The oldest value drops out
    # automatically once the window is full.
    fps_window = deque(maxlen=5)
    stereo.eval()

    # CUDA events can be re-recorded, so create them once, not per frame.
    prep_start = torch.cuda.Event(enable_timing=True)
    prep_end = torch.cuda.Event(enable_timing=True)
    run_start = torch.cuda.Event(enable_timing=True)
    run_end = torch.cuda.Event(enable_timing=True)

    for batch in kitti_train:
        # Time only the host-to-device copy of the stereo pair.
        prep_start.record()
        imgL, imgR = batch['imgL'].cuda(), batch['imgR'].cuda()
        prep_end.record()
        torch.cuda.synchronize()
        print('Data Preparation: {:.3f}'.format(
            prep_start.elapsed_time(prep_end)))

        # Left and right images are concatenated along dim 1 (presumably
        # channels) into the single input the scripted model takes.
        run_start.record()
        disp = stereo(torch.cat((imgL, imgR), 1))
        run_end.record()
        torch.cuda.synchronize()
        runtime = run_start.elapsed_time(run_end)  # milliseconds

        fps_window.append(1000 / runtime)
        avg_fps = sum(fps_window) / len(fps_window)
        print('Stereo runtime: {:.3f}'.format(1000 / avg_fps))

        disp_np = cv2.applyColorMap(
            disparity_to_uint8(disp), cv2.COLORMAP_PLASMA)
        # The raw left image is only displayed, so it stays on the CPU
        # instead of making a round trip through the GPU.
        image_np = (
            batch['imgLRaw'][0].permute(1, 2, 0).cpu().numpy()
            .astype(np.uint8))
        out_img = np.concatenate((image_np, disp_np), 0)
        cv2.putText(
            out_img,
            "%.1f fps" % (avg_fps),
            (10, image_np.shape[0] + 30),
            cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
        cv2.imshow('img', out_img)
        # Press 'q' in the window to stop the demo early.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    cv2.destroyAllWindows()