train_sampler.py

import argparse
import logging
import os
import os.path as osp
import random
import time

import torch

from data.sample_identity_dataset import SampleIdentityDataset
from models import create_model
from utils.logger import MessageLogger, get_root_logger, init_tb_logger
from utils.options import dict2str, dict_to_nonedict, parse
from utils.util import make_exp_dirs, set_random_seed
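
# Training entry point for the sampler: sets up experiment dirs, logging and
# seeds, builds train/val/test dataloaders, then runs the epoch loop with
# periodic logging, validation, and checkpointing.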


def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    args = parser.parse_args()
    opt = parse(args.opt, is_train=True)
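
    # at this point `opt` is a plain dict parsed from the YAML file; it is
    # converted to a NoneDict (missing keys read as None) further below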

    # mkdir and loggers
    make_exp_dirs(opt)
    log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
    logger = get_root_logger(
        logger_name='base', log_level=logging.INFO, log_file=log_file)
    logger.info(dict2str(opt))

    # initialize tensorboard logger
    tb_logger = None
    if opt['use_tb_logger'] and 'debug' not in opt['name']:
        tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])

    # random seed
    seed = opt['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info(f'Random seed: {seed}')
    set_random_seed(seed)

    # convert to NoneDict, which returns None for missing keys
    opt = dict_to_nonedict(opt)

    # set up data loader
    train_dataset = SampleIdentityDataset(opt['datasets']['train'])
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=opt['batch_size'],
        shuffle=True,
        num_workers=opt['num_workers'],
        persistent_workers=True,
        drop_last=True)
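    # NOTE: persistent_workers=True requires num_workers > 0, and
    # drop_last=True discards the final partial batch of each epoch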
    logger.info(f'Number of train set: {len(train_dataset)}.')
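    # (approximate) total number of optimization steps; with drop_last=True
    # the true count is num_epochs * (len(train_dataset) // batch_size),
    # which this slightly overestimates when the dataset size is not a
    # multiple of the batch size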
    opt['max_iters'] = opt['num_epochs'] * len(
        train_dataset) // opt['batch_size']

    val_dataset = SampleIdentityDataset(opt['datasets']['val'])
    val_loader = torch.utils.data.DataLoader(
        dataset=val_dataset, batch_size=opt['batch_size'], shuffle=False)
    logger.info(f'Number of val set: {len(val_dataset)}.')

    test_dataset = SampleIdentityDataset(opt['datasets']['test'])
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=opt['batch_size'], shuffle=False)
    logger.info(f'Number of test set: {len(test_dataset)}.')

    model = create_model(opt)

    # initialize the timers to "now" so the first measured interval is
    # meaningful (starting from 0 would report time-since-epoch on the
    # first iteration)
    data_time, iter_time = time.time(), time.time()
    current_iter = 0

    # create message logger (formatted outputs)
    msg_logger = MessageLogger(opt, current_iter, tb_logger)

    for epoch in range(opt['num_epochs']):
        lr = model.update_learning_rate(epoch, current_iter)
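
        # inner loop: one optimization step per batch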
        for _, batch_data in enumerate(train_loader):
            # seconds spent waiting on the dataloader for this batch
            data_time = time.time() - data_time

            current_iter += 1

            model.feed_data(batch_data)
            model.optimize_parameters()

            # seconds for the whole iteration (data loading + optimization)
            iter_time = time.time() - iter_time

            if current_iter % opt['print_freq'] == 0:
                log_vars = {'epoch': epoch, 'iter': current_iter}
                log_vars.update({'lrs': [lr]})
                log_vars.update({'time': iter_time, 'data_time': data_time})
                log_vars.update(model.get_current_log())
                msg_logger(log_vars)

            # reset timers for the next iteration
            data_time = time.time()
            iter_time = time.time()

        # periodic validation: dump visualizations for the val and test sets
        if epoch % opt['val_freq'] == 0 and epoch != 0:
            save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}'  # noqa
            os.makedirs(save_dir, exist_ok=opt['debug'])
            model.inference(val_loader, save_dir)

            save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}'  # noqa
            os.makedirs(save_dir, exist_ok=opt['debug'])
            model.inference(test_loader, save_dir)

            # save model
            model.save_network(
                model._denoise_fn,
                f'{opt["path"]["models"]}/sampler_epoch{epoch}.pth')


if __name__ == '__main__':
    main()