-
Notifications
You must be signed in to change notification settings - Fork 86
/
Copy pathoptions.py
119 lines (102 loc) · 4.51 KB
/
options.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import os.path as osp
import logging
import yaml
from utils.util import OrderedYaml
# YAML loader/dumper classes supplied by the project's utils.util.OrderedYaml helper.
# NOTE(review): presumably these keep mapping order when reading/writing option
# files -- confirm against utils.util. `Loader` is the one passed to yaml.load in parse().
Loader, Dumper = OrderedYaml()
def parse(opt_path, is_train=True):
    '''Load a YAML option file and fill in the derived settings.

    Besides reading the file, this function:
      * sets the ``CUDA_VISIBLE_DEVICES`` environment variable from
        ``opt['gpu_ids']`` and prints the equivalent ``export`` line;
      * stores ``is_train`` in the options;
      * annotates every entry of ``opt['datasets']`` with ``phase``,
        ``data_type`` and (when ``opt['distortion'] == 'sr'``) ``scale``;
      * expands ``~`` in the configured paths and adds the root,
        experiment (training) or result (testing) directories;
      * copies ``scale`` into ``opt['network_G']`` for the ``'sr'`` case.

    Args:
        opt_path: Path of the YAML option file.
        is_train: True to set up training paths, False for test paths.

    Returns:
        The option mapping produced by the YAML loader, modified in place.
    '''
    # NOTE(review): yaml.load runs with the project-specific Loader; only feed
    # it option files from a trusted source.
    with open(opt_path, mode='r') as stream:
        opt = yaml.load(stream, Loader=Loader)

    # export CUDA_VISIBLE_DEVICES
    visible_gpus = ','.join(map(str, opt['gpu_ids']))
    os.environ['CUDA_VISIBLE_DEVICES'] = visible_gpus
    print('export CUDA_VISIBLE_DEVICES=' + visible_gpus)

    opt['is_train'] = is_train
    is_sr = opt['distortion'] == 'sr'
    if is_sr:
        scale = opt['scale']

    # datasets
    for phase_key, dataset in opt['datasets'].items():
        # Only the part before the first underscore names the phase.
        dataset['phase'] = phase_key.split('_')[0]
        if is_sr:
            dataset['scale'] = scale

        # A dataset counts as LMDB as soon as one of its roots ends in 'lmdb'.
        # ('dataroot_GT_bg' is not handled here; that code path is disabled.)
        uses_lmdb = False
        for root_key in ('dataroot_GT', 'dataroot_LQ'):
            if dataset.get(root_key, None) is not None:
                dataset[root_key] = osp.expanduser(dataset[root_key])
                if dataset[root_key].endswith('lmdb'):
                    uses_lmdb = True
        dataset['data_type'] = 'lmdb' if uses_lmdb else 'img'

        if dataset['mode'].endswith('mc'):  # for memcached
            dataset['data_type'] = 'mc'
            dataset['mode'] = dataset['mode'].replace('_mc', '')

    # path
    path_opt = opt['path']
    for key, path in path_opt.items():
        # 'strict_load' is skipped: it is not expanded like the path entries.
        if path and key != 'strict_load':
            path_opt[key] = osp.expanduser(path)

    # Joining three pardirs onto the file path itself resolves to the
    # directory two levels above the one containing this file.
    root = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
    path_opt['root'] = root

    if is_train:
        experiments_root = osp.join(root, 'experiments', opt['name'])
        path_opt['experiments_root'] = experiments_root
        path_opt['models'] = osp.join(experiments_root, 'models')
        path_opt['training_state'] = osp.join(experiments_root, 'training_state')
        path_opt['log'] = experiments_root
        path_opt['val_images'] = osp.join(experiments_root, 'val_images')

        # change some options for debug mode
        if 'debug' in opt['name']:
            opt['train']['val_freq'] = 8
            opt['logger']['print_freq'] = 1
            opt['logger']['save_checkpoint_freq'] = 8
    else:  # test
        results_root = osp.join(root, 'results', opt['name'])
        path_opt['results_root'] = results_root
        path_opt['log'] = results_root

    # network
    if is_sr and 'network_G' in opt:
        opt['network_G']['scale'] = scale

    return opt
def dict2str(opt, indent_l=1):
    '''Render a (possibly nested) dict as an indented string for the logger.

    Each entry is written on its own line as ``key: value``, indented by
    ``indent_l * 2`` spaces. A value that is itself a dict is rendered
    recursively, one level deeper, between ``key:[`` and ``]`` lines.

    Keys are converted with ``str()`` so that non-string keys (for example
    integer keys coming from a YAML file) do not raise ``TypeError``.

    Args:
        opt: The mapping to render.
        indent_l: Indentation level of the entries of ``opt``.

    Returns:
        The formatted text; an empty string for an empty mapping.
    '''
    pad = ' ' * (indent_l * 2)
    # Collect the pieces and join once instead of growing a string with +=.
    pieces = []
    for k, v in opt.items():
        key = str(k)
        if isinstance(v, dict):
            pieces.append(pad + key + ':[\n')
            pieces.append(dict2str(v, indent_l + 1))
            pieces.append(pad + ']\n')
        else:
            pieces.append(pad + key + ': ' + str(v) + '\n')
    return ''.join(pieces)
class NoneDict(dict):
    '''A dict whose item lookup yields None for absent keys instead of raising.'''

    def __missing__(self, key):
        # Called by dict item lookup when `key` is absent; the key is not stored.
        return None
def dict_to_nonedict(opt):
    '''Recursively convert dicts to NoneDict, which returns None for a missing key.

    Dicts become ``NoneDict`` instances with converted values, lists become
    new lists with converted elements, and any other object is returned
    unchanged.

    Args:
        opt: A dict, a list, or any other value from an option tree.

    Returns:
        The converted structure (or ``opt`` itself when it is neither a dict
        nor a list).
    '''
    if isinstance(opt, dict):
        converted = {key: dict_to_nonedict(sub_opt) for key, sub_opt in opt.items()}
        # Pass the mapping positionally: NoneDict(**converted) would raise
        # TypeError for any key that is not a string (e.g. integer YAML keys).
        return NoneDict(converted)
    if isinstance(opt, list):
        return [dict_to_nonedict(sub_opt) for sub_opt in opt]
    return opt
def check_resume(opt, resume_iter):
    '''Point the pretrain_model paths at the checkpoints of a resumed run.

    Does nothing unless ``opt['path']['resume_state']`` is truthy. Otherwise
    ``pretrain_model_G`` is set to ``<models>/<resume_iter>_G.pth`` and, when
    ``'gan'`` occurs in ``opt['model']``, ``pretrain_model_D`` is set to
    ``<models>/<resume_iter>_D.pth``. A warning is logged if either path had
    already been configured, since it is overridden/ignored.

    Args:
        opt: Option mapping; ``opt['path']`` is modified in place.
        resume_iter: Iteration identifier used as the checkpoint file prefix.
    '''
    log = logging.getLogger('base')
    paths = opt['path']
    if not paths['resume_state']:
        return

    preset = any(
        paths.get(name, None) is not None
        for name in ('pretrain_model_G', 'pretrain_model_D')
    )
    if preset:
        log.warning('pretrain_model path will be ignored when resuming training.')

    generator_ckpt = osp.join(paths['models'], '{}_G.pth'.format(resume_iter))
    paths['pretrain_model_G'] = generator_ckpt
    log.info('Set [pretrain_model_G] to ' + paths['pretrain_model_G'])

    # Only GAN-style models also get a discriminator checkpoint.
    if 'gan' in opt['model']:
        discriminator_ckpt = osp.join(paths['models'], '{}_D.pth'.format(resume_iter))
        paths['pretrain_model_D'] = discriminator_ckpt
        log.info('Set [pretrain_model_D] to ' + paths['pretrain_model_D'])