render_texture.py
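"""Render predictions from a trained PipeLine model for a test set of UV maps.

Loads a checkpoint, runs the model over each UV map in the test split, and
saves the result either as a single video or as one PNG image per frame.
"""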
import argparse
import os
import sys

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
import tqdm
from skimage import img_as_ubyte
from torch.utils.data import DataLoader

import config
from dataset.eval_dataset import EvalDataset
from model.pipeline import PipeLine

parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, default=config.DATA_DIR, help='path to the data directory')
parser.add_argument('--test', default=config.TEST_SET, help='index list of test uv_maps')
parser.add_argument('--checkpoint', type=str, default=config.CHECKPOINT_DIR, help='directory containing checkpoints')
parser.add_argument('--load', type=str, default=config.TEST_LOAD, help='name of the checkpoint file to load')
parser.add_argument('--batch', type=int, default=config.BATCH_SIZE, help='batch size')
parser.add_argument('--save', type=str, default=config.SAVE_DIR, help='directory to save rendered results')
parser.add_argument('--out_mode', type=str, default=config.OUT_MODE, choices=('video', 'image'), help='write a single video or individual images')
parser.add_argument('--fps', type=int, default=config.FPS, help='frame rate of the output video')
args = parser.parse_args()

if __name__ == '__main__':
    checkpoint_file = os.path.join(args.checkpoint, args.load)
    if not os.path.exists(checkpoint_file):
        print('checkpoint does not exist!')
        sys.exit()
    if not os.path.exists(args.save):
        os.makedirs(args.save)

    # Build the evaluation dataset and dataloader (no shuffling, so frames stay in order).
    dataset = EvalDataset(args.data, args.test, False)
    dataloader = DataLoader(dataset, batch_size=args.batch, shuffle=False, num_workers=4,
                            collate_fn=EvalDataset.get_collect_fn(False))

    # Load the trained pipeline and switch it to inference mode.
    model = torch.load(checkpoint_file)
    model = model.to('cuda')
    model.eval()
    torch.set_grad_enabled(False)

    writer = None
    if args.out_mode == 'video':
        fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
        writer = cv2.VideoWriter(os.path.join(args.save, 'render.mp4'), fourcc, args.fps,
                                 (dataset.width, dataset.height), True)

    print('Evaluating started')
    for samples in tqdm.tqdm(dataloader):
        uv_maps, masks, idxs = samples
        preds = model(uv_maps.cuda()).cpu()
        preds.masked_fill_(masks, 0)  # fill invalid (masked) pixels with 0

        # save result
        if args.out_mode == 'video':
            preds = preds.numpy()
            preds = np.clip(preds, -1.0, 1.0)
            for i in range(len(idxs)):
                image = img_as_ubyte(preds[i])          # [-1, 1] floats -> uint8
                image = np.transpose(image, (1, 2, 0))  # CHW -> HWC
                image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)  # OpenCV expects BGR
                writer.write(image)
        else:
            for i in range(len(idxs)):
                image = transforms.ToPILImage()(preds[i])
                image.save(os.path.join(args.save, '{}_render.png'.format(idxs[i])))

    if writer is not None:
        writer.release()  # finalize the video file
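
# Example invocation (a sketch; the paths and checkpoint name below are
# hypothetical -- substitute your own, or rely on the defaults in config.py):
#   python render_texture.py --data ./data --checkpoint ./checkpoint --load best.pth \
#       --save ./output --out_mode video --fps 30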