-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_test_SupernovaSR.py
97 lines (73 loc) · 3.5 KB
/
main_test_SupernovaSR.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os.path
import logging
import torch
from utils import utils_logger
from utils import utils_image as util
# from utils import utils_model
from models.network_rrdbnet import RRDBNet as net
def main():
    """Run blind super-resolution over the configured test sets and save results.

    For each name in ``model_names`` an RRDBNet is built with the model's scale
    factor ``sf``, and its weights are loaded from ``model_zoo/<model_name>.pth``.
    Every image returned by ``util.get_image_paths`` for
    ``testsets/<testset_L>`` is then run through the network. When
    ``save_results`` is true, the output is written to
    ``testsets/<testset_L>_results_x<sf>/<img_name>_<model_name>.png``.
    Progress is logged through the ``blind_sr_log`` logger, which
    ``utils_logger.logger_info`` configures to write ``blind_sr_log.log``.

    Runs on CUDA when available, otherwise on CPU.
    """
    utils_logger.logger_info('blind_sr_log', log_path='blind_sr_log.log')
    logger = logging.getLogger('blind_sr_log')

    # print(torch.__version__)  # pytorch version
    # print(torch.version.cuda)  # cuda version
    # print(torch.backends.cudnn.version())  # cudnn version

    testsets = 'testsets'  # fixed, set path of testsets
    testset_Ls = ['RealSRSet']  # ['RealSRSet','DPED']
    # Other model names used with this script:
    # ['RRDB','ESRGAN','FSSR_DPED','FSSR_JPEG','RealSR_DPED','RealSR_JPEG']
    model_names = ['SupernovaSR']  # 'SupernovaSRx2' for scale factor 2
    save_results = True

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    for model_name in model_names:
        # Decide the scale factor per model. Previously sf was only ever
        # overwritten, so an x2 model leaked sf=2 into every later model.
        sf = 2 if model_name in ('SupernovaSRx2',) else 4
        model_path = os.path.join('model_zoo', model_name + '.pth')  # set model path
        logger.info('{:>16s} : {:s}'.format('Model Name', model_name))
        # torch.cuda.set_device(0)  # set GPU ID
        if use_cuda:
            # current_device() raises without CUDA, so only query it when present.
            logger.info('{:>16s} : {:<d}'.format('GPU ID', torch.cuda.current_device()))
            torch.cuda.empty_cache()
        else:
            logger.info('{:>16s} : {:s}'.format('Device', str(device)))

        # --------------------------------
        # define network and load model
        # --------------------------------
        model = net(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=sf)  # define network
        # Alternative loader for checkpoints whose key names differ but whose
        # parameter order matches the network:
        # model_old = torch.load(model_path)
        # state_dict = model.state_dict()
        # for ((key, param), (key2, param2)) in zip(model_old.items(), state_dict.items()):
        #     state_dict[key2] = param
        # model.load_state_dict(state_dict, strict=True)
        # map_location='cpu' lets a checkpoint saved from a GPU load on a
        # CPU-only machine; the model is moved to `device` below.
        model.load_state_dict(torch.load(model_path, map_location='cpu'), strict=True)
        model.eval()
        for param in model.parameters():
            param.requires_grad = False
        model = model.to(device)
        if use_cuda:
            torch.cuda.empty_cache()

        for testset_L in testset_Ls:
            L_path = os.path.join(testsets, testset_L)
            # E_path = os.path.join(testsets, testset_L + '_' + model_name)
            E_path = os.path.join(testsets, testset_L + '_results_x' + str(sf))
            util.mkdir(E_path)
            logger.info('{:>16s} : {:s}'.format('Input Path', L_path))
            logger.info('{:>16s} : {:s}'.format('Output Path', E_path))

            for idx, img in enumerate(util.get_image_paths(L_path), start=1):
                # --------------------------------
                # (1) img_L: read the input image and move it to `device`
                # --------------------------------
                img_name, ext = os.path.splitext(os.path.basename(img))
                logger.info('{:->4d} --> {:<s} --> x{:<d}--> {:<s}'.format(idx, model_name, sf, img_name + ext))
                img_L = util.imread_uint(img, n_channels=3)
                # NOTE(review): presumably a batched (1, C, H, W) tensor; helper not shown.
                img_L = util.uint2tensor4(img_L)
                img_L = img_L.to(device)

                # --------------------------------
                # (2) inference (no autograd graph needed)
                # --------------------------------
                with torch.no_grad():
                    img_E = model(img_L)

                # --------------------------------
                # (3) img_E: convert back to an image and optionally save
                # --------------------------------
                img_E = util.tensor2uint(img_E)
                if save_results:
                    util.imsave(img_E, os.path.join(E_path, img_name + '_' + model_name + '.png'))
if __name__ == '__main__':
    # Script entry point: run the batch only when executed directly, not on import.
    main()