-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathviewer.py
178 lines (158 loc) · 7.31 KB
/
viewer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import copy
import time
import torch
from scene import Scene
import os
from tqdm import tqdm
from os import makedirs
from gaussian_renderer import render
import torchvision
from utils.general_utils import safe_state
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, get_combined_args
from gaussian_renderer import GaussianModel
import cv2
# Resolution ladder: powers of two from full resolution (1) down to 1/2**max_reso_pow.
max_reso_pow = 7
# max_reso_pow = 5
# max_reso_pow = 1

# Training scales: 1, 2, 4, ..., 2**max_reso_pow (i.e. 1~128 for pow 7).
train_reso_scales = [2 ** p for p in range(max_reso_pow + 1)]
# Test scales currently mirror the training scales (no intermediate half scales).
# test_reso_scales = train_reso_scales + [(2**i + 2**(i+1)) / 2 for i in range(max_reso_pow)] # 1~128, include half scales
test_reso_scales = sorted(train_reso_scales)
# Union of both ladders, ascending — passed to Scene as resolution_scales.
full_reso_scales = sorted(set(train_reso_scales) | set(test_reso_scales))
def render_interactive(dataset: ModelParams, iteration: int, pipeline: PipelineParams,
                       anti_alias=False, render_once=False):
    """Open an interactive OpenCV viewer for a trained Gaussian-splatting scene.

    Renders the current test camera in a loop and reacts to key presses:
        4/6  pan x          8/2  pan y          7/9  dolly z
        1/3  zoom (camera scale)                [/]  gaussian scaling modifier
        ;/'  fade size      x/c  prev/next test camera
        ./   decrease/increase resolution level z    reset view
        s    save current rendering as PNG      q    quit

    Args:
        dataset: model parameters (provides sh_degree, white_background, model_path).
        iteration: checkpoint iteration to load (-1 typically means latest).
        pipeline: pipeline parameters forwarded to render().
        anti_alias: if True, enable both small- and large-splat filtering.
        render_once: if True, break out of the (currently length-1) render
            timing loop after the first render of each frame.
    """
    with torch.no_grad():
        gaussians = GaussianModel(dataset.sh_degree)
        scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False, resolution_scales=full_reso_scales)

        bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

        # # prune gaussians far from center
        # gaussians.filter_center(scene.cameras_extent)

        gaussians.pre_cat_feature()

        # Work on a deep copy so key-driven edits to view.T / view.scale do
        # not mutate the camera stored in the Scene.
        view_idx = 0
        view = copy.deepcopy(scene.getTestCameras(scale=test_reso_scales[0])[view_idx])
        # view = copy.deepcopy(scene.getTestCameras(scale=test_reso_scales[5])[view_idx])

        gs_scale = 1.0  # gaussian size relative to the original size
        fade_size = 1.0
        reso_idx = 0
        view_resolution = None  # fixed on first frame; later frames are resized to it
        if anti_alias:
            filter_small = True
            filter_large = True
        else:
            filter_small = False
            filter_large = False

        while True:
            view.cal_transform()

            torch.cuda.synchronize()
            time_start = time.perf_counter()
            # Length-1 loop kept as a hook for timing over multiple renders.
            for _ in range(1):
                results = render(view, gaussians, pipeline, background, scaling_modifier=gs_scale,
                                 filter_small=filter_small, filter_large=filter_large, fade_size=fade_size)
                rendering = results["render"]
                acc_pixel_size = results["acc_pixel_size"]
                depth = results["depth"]
                if render_once:
                    break
            torch.cuda.synchronize()
            render_time = time.perf_counter() - time_start

            rendering = torch.permute(rendering, (1, 2, 0))  # CHW -> HWC
            rendering = rendering.cpu().numpy()
            rendering = cv2.cvtColor(rendering, cv2.COLOR_RGB2BGR)

            # normalize acc_pixel_size for display (values beyond 10 saturate)
            acc_pixel_size = torch.clip(acc_pixel_size / 10, 0, 1)
            acc_pixel_size = acc_pixel_size.cpu().numpy()

            # normalize depth to [0, 1] for display
            depth = torch.clip(depth / torch.max(depth), 0, 1)
            depth = depth.cpu().numpy()

            if view_resolution is None:
                view_resolution = rendering.shape[:2]
            else:
                # Lower-resolution levels render smaller images; upscale with
                # nearest-neighbor so the window size stays constant.
                rendering = cv2.resize(rendering, view_resolution[::-1], interpolation=cv2.INTER_NEAREST)
                acc_pixel_size = cv2.resize(acc_pixel_size, view_resolution[::-1], interpolation=cv2.INTER_NEAREST)
                depth = cv2.resize(depth, view_resolution[::-1], interpolation=cv2.INTER_NEAREST)

            cv2.imshow("acc_pixel_size", acc_pixel_size)
            cv2.imshow("depth", depth)
            cv2.imshow("rendering", rendering)
            cv2.setWindowTitle("rendering", f"{render_time * 1000:.2f}ms")

            key = cv2.waitKey(0)
            if key == ord('q'):
                break
            elif key == ord('4'):
                view.T[0] += 0.1
            elif key == ord('6'):
                view.T[0] -= 0.1
            elif key == ord('8'):
                view.T[1] += 0.1
            elif key == ord('2'):
                view.T[1] -= 0.1
            elif key == ord('7'):
                view.T[2] += 0.5
            elif key == ord('9'):
                view.T[2] -= 0.5
            elif key == ord('1'):
                view.scale *= 0.9
            elif key == ord('3'):
                view.scale /= 0.9
            elif key == ord('['):
                gs_scale = max(0.1, gs_scale - 0.1)
            elif key == ord(']'):
                gs_scale = min(2.0, gs_scale + 0.1)
            elif key == ord(';'):
                fade_size = max(0.1, fade_size - 0.1)
            elif key == ord('\''):
                fade_size = min(2.0, fade_size + 0.1)
            elif key == ord('x'):
                view_idx = (view_idx - 1) % len(scene.getTestCameras())
                view = copy.deepcopy(scene.getTestCameras(scale=test_reso_scales[reso_idx])[view_idx])
            elif key == ord('c'):
                view_idx = (view_idx + 1) % len(scene.getTestCameras())
                view = copy.deepcopy(scene.getTestCameras(scale=test_reso_scales[reso_idx])[view_idx])
            elif key == ord('z'):
                # reset view, scales and resolution level
                view = copy.deepcopy(scene.getTestCameras(scale=test_reso_scales[reso_idx])[view_idx])
                gs_scale = 1.0
                fade_size = 1.0
                reso_idx = 0
            elif key == ord('/'):
                reso_idx = min(reso_idx + 1, max_reso_pow)
                view = copy.deepcopy(scene.getTestCameras(scale=test_reso_scales[reso_idx])[view_idx])
            elif key == ord('.'):
                reso_idx = max(reso_idx - 1, 0)
                view = copy.deepcopy(scene.getTestCameras(scale=test_reso_scales[reso_idx])[view_idx])
            elif key == ord('s'):
                # save figure
                output_dir = dataset.model_path
                output_dir = os.path.join(output_dir, "viewer")
                makedirs(output_dir, exist_ok=True)
                image_name = f"{view_idx}_{reso_idx}.png"
                image_path = os.path.join(output_dir, image_name)
                # BUGFIX: rendering is a float BGR image in [0, 1];
                # cv2.imwrite expects 8-bit data for PNG, so convert
                # explicitly instead of passing floats scaled by 255.
                image_u8 = (rendering * 255).clip(0, 255).astype("uint8")
                cv2.imwrite(image_path, image_u8)
                print(f"image saved to {image_path}")
if __name__ == "__main__":
    # Build the command-line interface: model/pipeline parameter groups
    # plus the viewer-specific flags.
    arg_parser = ArgumentParser(description="Testing script parameters")
    model_params = ModelParams(arg_parser, sentinel=True)
    pipeline_params = PipelineParams(arg_parser)
    arg_parser.add_argument("--iteration", default=-1, type=int)
    arg_parser.add_argument("--quiet", action="store_true")
    arg_parser.add_argument("--anti_alias", action="store_true", default=False)
    arg_parser.add_argument("--render_once", action="store_true", default=False)
    args = get_combined_args(arg_parser)
    print("Rendering " + args.model_path)

    # Initialize system state (RNG)
    safe_state(args.quiet)

    render_interactive(
        model_params.extract(args),
        args.iteration,
        pipeline_params.extract(args),
        anti_alias=args.anti_alias,
        render_once=args.render_once,
    )