# test.py: loads pretrained weights and evaluates the model on the test split.
import os
from option import args
os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_name
import torch
import torch.optim as optim
import torch.nn as nn
from data import dataset
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from utils.metric import torch_psnr
import utils.util as util
import torchvision
import models
# TensorBoard writer; its log directory is ./logs/<args.writer_name>.
writer = SummaryWriter(log_dir='./logs/{}'.format(args.writer_name))

# Test split read from <args.data_path>/test, built with train=False.
testdata1 = dataset.Data(
    root=os.path.join(args.data_path, 'test'),
    args=args,
    train=False,
)

# One sample per batch, in dataset order, loaded by a single worker process.
testset1 = DataLoader(
    dataset=testdata1,
    batch_size=1,
    shuffle=False,
    num_workers=1,
)

# Device that evaluation batches are moved onto (see to_device / eval_model).
device = torch.device(args.device)
def to_device(sample, device):
    """Move every entry of ``sample`` except ``'img_name'`` onto ``device``.

    The dict is updated in place: each moved value is replaced by the result
    of ``value.to(device, non_blocking=True)``. The ``'img_name'`` entry is
    left untouched.

    Args:
        sample: Mapping whose values (other than ``'img_name'``) expose a
            ``.to(device, non_blocking=...)`` method.
        device: Target passed straight through to ``.to``.

    Returns:
        The same ``sample`` object that was passed in.
    """
    # Snapshot the keys first, then overwrite the values one by one.
    movable_keys = [name for name in sample if name != 'img_name']
    for name in movable_keys:
        sample[name] = sample[name].to(device, non_blocking=True)
    return sample
def eval_model(model, dataset, name, epoch, args):
    """Evaluate ``model`` on ``dataset``, save its outputs and print metrics.

    For every batch the model output ``sr['img_out'][0]`` is written as an
    image into ``<args.save_path>/<args.writer_name>/result-<name>/`` under
    the file name ``data['img_name'][0]``. After the loop the mean PSNR, the
    mean of the second metric returned by ``torch_psnr`` and the elapsed time
    are printed.

    Args:
        model: Network to evaluate. It is switched to eval mode, called with
            the batch dict returned by ``to_device`` and must return a mapping
            that contains ``'img_out'``.
        dataset: Sized iterable of batch dicts holding at least ``'img_gt'``
            and ``'img_name'``. NOTE: inside this function the parameter
            shadows the imported ``dataset`` module.
        name: Label used in the printed summary and in the output directory.
        epoch: Zero-based epoch index; printed as ``epoch + 1``.
        args: Options object; ``save_path`` and ``writer_name`` are read.

    Returns:
        None.
    """
    model.eval()

    # The output directory does not change per batch, so build it once.
    result_dir = os.path.join(args.save_path, args.writer_name, 'result-{}'.format(name))
    os.makedirs(result_dir, exist_ok=True)

    num_batches = len(dataset)
    if num_batches == 0:
        # Avoid dividing by zero when averaging over an empty loader.
        print("Epoch:{}, {}, no samples to evaluate".format(epoch+1, name))
        return

    psnr_total = 0
    ssim_total = 0
    timer_test = util.timer()

    # Inference only: without no_grad every forward pass records an autograd
    # graph, and the running sums below would keep those graphs alive.
    with torch.no_grad():
        for data in dataset:
            sr = model(to_device(data, device))
            psnr_c, ssim_c = torch_psnr(data['img_gt'], sr['img_out'])
            psnr_total = psnr_total + psnr_c
            ssim_total = ssim_total + ssim_c
            torchvision.utils.save_image(
                sr['img_out'][0],
                os.path.join(result_dir, '{}'.format(str(data['img_name'][0]))))

    print("Epoch:{}, {}, psnr: {:.3f}".format(epoch+1, name, psnr_total/num_batches))
    # The second metric used to be accumulated but never reported.
    # NOTE(review): labelled "ssim" after the variable name only - confirm
    # against utils.metric.torch_psnr.
    print("Epoch:{}, {}, ssim: {:.3f}".format(epoch+1, name, float(ssim_total/num_batches)))
    print('Forward: {:.2f}s\n'.format(timer_test.toc()))
if __name__ == "__main__":
    # Build the network from the parsed command-line options.
    model = models.get_model(args)

    # map_location='cpu' lets a checkpoint saved on any CUDA device be read
    # even when that device is hidden by CUDA_VISIBLE_DEVICES or absent;
    # load_state_dict then copies the values into the model's own parameters.
    # NOTE(security): torch.load unpickles the file - only load checkpoints
    # from a trusted source.
    pretrained_dict = torch.load(args.load, map_location='cpu')
    model.load_state_dict(pretrained_dict)

    # NOTE(review): the model is not moved to `device` in this file;
    # presumably models.get_model handles placement - confirm.
    # Epoch index 0 is printed by eval_model as "Epoch:1".
    eval_model(model, testset1, args.save_name, 0, args)