This repository has been archived by the owner on Apr 1, 2024. It is now read-only.
forked from princeton-vl/DPVO
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdemo.py
110 lines (82 loc) · 3.6 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import cv2
import numpy as np
import glob
import os.path as osp
import os
import torch
from pathlib import Path
from multiprocessing import Process, Queue
from plyfile import PlyElement, PlyData
from dpvo.utils import Timer
from dpvo.dpvo import DPVO
from dpvo.config import cfg
from dpvo.stream import image_stream, video_stream
from dpvo.plot_utils import plot_trajectory, save_trajectory_tum_format
# Module-level skip constant. Nothing in the visible part of this script reads
# it: run() takes its own `skip` parameter and the CLI exposes `--skip`.
# NOTE(review): confirm whether anything outside this file still relies on it.
SKIP = 0
def show_image(image, t=0):
    """Show *image* in the OpenCV window named ``'image'``.

    The tensor's axes are reordered with ``permute(1, 2, 0)``, it is copied to
    the CPU and converted to a NumPy array, and every value is divided by
    255.0 before being handed to ``cv2.imshow``.

    Args:
        image: Tensor exposing ``permute``/``cpu``/``numpy``. Presumably a
            channels-first (C, H, W) image with values in 0-255 -- confirm
            against callers.
        t: Delay argument passed straight to ``cv2.waitKey``.
    """
    frame = image.permute(1, 2, 0)
    pixels = frame.cpu().numpy()
    scaled = pixels / 255.0
    cv2.imshow('image', scaled)
    cv2.waitKey(t)
def _reconstruction_to_ply(slam):
    """Pack the first ``slam.m`` points and colors of *slam* into a PlyData.

    Each vertex carries float32 ``x``/``y``/``z`` and uint8
    ``red``/``green``/``blue`` fields; the result is flagged as a text PLY.
    """
    xyz = slam.points_.cpu().numpy()[:slam.m]
    rgb = slam.colors_.view(-1, 3).cpu().numpy()[:slam.m]
    # Pair points with colors row by row, stopping at the shorter array.
    count = min(len(xyz), len(rgb))

    vertices = np.empty(count, dtype=[
        ('x', '<f4'), ('y', '<f4'), ('z', '<f4'),
        ('red', 'u1'), ('green', 'u1'), ('blue', 'u1'),
    ])
    # Fill the structured array column by column instead of building one
    # Python tuple per point.
    for column, field in enumerate(('x', 'y', 'z')):
        vertices[field] = xyz[:count, column]
    for column, field in enumerate(('red', 'green', 'blue')):
        vertices[field] = rgb[:count, column]

    # No list-valued properties exist, so describe() needs no type overrides.
    element = PlyElement.describe(vertices, 'vertex')
    return PlyData([element], text=True)


@torch.no_grad()
def run(cfg, network, imagedir, calib, stride=1, skip=0, viz=False, timeit=False, save_reconstruction=False):
    """Run DPVO over an image directory or a video and return the result.

    A reader process pushes ``(t, image, intrinsics)`` items into a bounded
    queue; a negative ``t`` marks the end of the stream. Every frame is moved
    to the GPU and fed to the DPVO instance, which is created lazily from the
    first frame's height and width. After the stream ends, 12 additional
    ``update()`` iterations are run before ``terminate()`` is called.

    Args:
        cfg: Configuration object forwarded to ``DPVO``.
        network: Network argument forwarded to ``DPVO``.
        imagedir: Input path. If it is a directory, ``image_stream`` reads
            it; otherwise ``video_stream`` is used.
        calib: Calibration argument forwarded to the stream function.
        stride: Forwarded to the stream function.
        skip: Forwarded to the stream function.
        viz: Forwarded to ``DPVO``.
        timeit: Enables the ``Timer`` around each SLAM step.
        save_reconstruction: When true, also return the point cloud.

    Returns:
        ``slam.terminate()``'s result, or the tuple
        ``(slam.terminate() result, PlyData)`` when *save_reconstruction* is
        true.

    Raises:
        RuntimeError: If the stream ended without delivering a single frame.
    """
    queue = Queue(maxsize=8)
    stream = image_stream if os.path.isdir(imagedir) else video_stream
    reader = Process(target=stream, args=(queue, imagedir, calib, stride, skip))
    reader.start()

    slam = None
    try:
        while True:
            t, image, intrinsics = queue.get()
            if t < 0:
                break

            image = torch.from_numpy(image).permute(2, 0, 1).cuda()
            intrinsics = torch.from_numpy(intrinsics).cuda()

            if slam is None:
                slam = DPVO(cfg, network, ht=image.shape[1], wd=image.shape[2], viz=viz)

            with Timer("SLAM", enabled=timeit):
                slam(t, image, intrinsics)
    except BaseException:
        # The reader may be blocked putting into the full queue; without this
        # the join below (and interpreter shutdown) would wait on it forever.
        reader.terminate()
        raise
    finally:
        reader.join()

    if slam is None:
        # Previously this surfaced as an AttributeError on None.
        raise RuntimeError(f"no frames were read from {imagedir!r}")

    for _ in range(12):
        slam.update()

    if save_reconstruction:
        # Snapshot the point cloud before terminate(), as the original did.
        ply_data = _reconstruction_to_ply(slam)
        return slam.terminate(), ply_data

    return slam.terminate()
def main():
    """Command-line entry point: run DPVO on one input and save the outputs.

    Parses the arguments, merges the YAML config into ``cfg``, runs ``run``
    and then, depending on the flags, writes ``<name>.ply``, a TUM-format
    trajectory under ``saved_trajectories/`` and/or a PDF plot under
    ``trajectory_plots/``. ``<name>`` is the stem of ``--imagedir``.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--network', type=str, default='dpvo.pth')
    # Required: both os.path.isdir() in run() and Path() below raise
    # TypeError when given None, so fail early with a usage message instead.
    parser.add_argument('--imagedir', type=str, required=True)
    parser.add_argument('--calib', type=str)
    parser.add_argument('--stride', type=int, default=2)
    parser.add_argument('--skip', type=int, default=0)
    parser.add_argument('--buffer', type=int, default=2048)
    parser.add_argument('--config', default="config/default.yaml")
    parser.add_argument('--timeit', action='store_true')
    parser.add_argument('--viz', action="store_true")
    parser.add_argument('--plot', action="store_true")
    parser.add_argument('--save_reconstruction', action="store_true")
    parser.add_argument('--save_trajectory', action="store_true")
    args = parser.parse_args()

    cfg.merge_from_file(args.config)
    cfg.BUFFER_SIZE = args.buffer

    print("Running with config...")
    print(cfg)

    result = run(cfg, args.network, args.imagedir, args.calib, args.stride,
                 args.skip, args.viz, args.timeit, args.save_reconstruction)
    name = Path(args.imagedir).stem

    if args.save_reconstruction:
        # run() returns (trajectory, PlyData) only when this flag is set.
        pred_traj, ply_data = result
        ply_data.write(f"{name}.ply")
        print(f"Saved {name}.ply")
    else:
        pred_traj = result

    if args.save_trajectory:
        Path("saved_trajectories").mkdir(exist_ok=True)
        save_trajectory_tum_format(pred_traj, f"saved_trajectories/{name}.txt")

    if args.plot:
        Path("trajectory_plots").mkdir(exist_ok=True)
        plot_trajectory(pred_traj, title=f"DPVO Trajectory Prediction for {name}", filename=f"trajectory_plots/{name}.pdf")


if __name__ == '__main__':
    main()