#!/usr/bin/env python
# Copyright (c) Baidu, Inc. and its affiliates.
"""
This training script is mainly constructed on train_net.py.
Additionally, this script is specialized for the training of supernet.
Moreover, this script adds a function of self-distillation.
If specifing `teacher_model_path` in the given config file, teacher model will
be built, otherwise teacher model is None.
"""
import logging
import os
import sys

# Restrict training to the first two GPUs before paddle is imported.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

import paddle
import numpy as np

# Make the project package importable (hard-coded repo location).
sys.path.append('/home/shangzaixing/code/PAZHOUbase/')
print(sys.path)

# Seed Paddle from the SEED env var (default 42) for reproducibility.
SEED = os.getenv("SEED", "42")
paddle.seed(int(SEED))
# np.random.seed(int(SEED))
from utils.events import CommonMetricSacredWriter
from engine.hooks import LRScheduler
from utils.config import auto_adjust_cfg
from fastreid.utils.checkpoint import Checkpointer
from detectron2.config import LazyConfig, instantiate
from detectron2.engine import (
AMPTrainer,
SimpleTrainer,
default_argument_parser,
default_setup,
default_writers,
hooks,
)
from evaluation import print_csv_format
from evaluation.evaluator import inference_on_dataset
from evaluation.seg_evaluator import seg_inference_on_dataset, seg_inference_on_test_dataset
from utils import comm
import paddle.distributed as dist
logger = logging.getLogger("ufo")


def do_test(cfg, model, _run=None, subnet_mode="largest"):
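    """Run every test dataloader in `cfg.dataloader.test` and collect metrics.

    Each dataloader is dispatched to a recognition, segmentation, or detection
    evaluator depending on its dataset. Returns a dict keyed by
    '<task_name>.<dataset_name>.<metric>'; results are also logged to sacred
    through `_run` when it is provided.
    """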
if "evaluator" in cfg.dataloader:
dataloaders = instantiate(cfg.dataloader.test)
rets = {}
for idx, (dataloader, evaluator_cfg) in enumerate(zip(dataloaders, cfg.dataloader.evaluator)):
task_name = '.'.join(list(dataloader.task_loaders.keys()))
dataset_name = dataloader.task_loaders[task_name].dataset.dataset_name
if (hasattr(cfg.train, 'selected_task_names')) and (task_name not in cfg.train.selected_task_names):
continue
print('=' * 10, dataset_name, '=' * 10)
            # recognition: datasets with a query/gallery split expose `num_query`
            dataset = list(dataloader.task_loaders.values())[0].dataset
            if hasattr(dataset, 'num_query'):
                evaluator_cfg.num_query = dataset.num_query
                evaluator_cfg.num_valid_samples = dataset.num_valid_samples
                evaluator_cfg.labels = dataset.labels
                evaluator = instantiate(evaluator_cfg)
                ret = inference_on_dataset(model, dataloader, evaluator)
# segmentation
            elif dataset_name in ['Cityscapes', 'BDD100K', 'InferDataset']:
evaluator = instantiate(evaluator_cfg)
if evaluator.mode == 'train':
ret = seg_inference_on_dataset(model, dataloader, evaluator)
else:
print("seg_inference_on_test_dataset")
ret = seg_inference_on_test_dataset(model, dataloader, evaluator)
# detection
else:
                evaluator_cfg.anno_file = dataset.get_anno()
                evaluator_cfg.clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
evaluator = instantiate(evaluator_cfg)
ret = inference_on_dataset(model, dataloader, evaluator)
print_csv_format(ret)
for metric, res in ret.items():
rets['{}.{}.{}'.format(task_name, dataset_name, metric)] = res
if _run is not None:
                    _run.log_scalar('{}.{}.{}'.format(task_name, dataset_name, metric), res)
print('{}.{}.{}'.format(task_name, dataset_name, metric), res)
        return rets


def do_train(args, cfg, cfg_for_sacred=None, _run=None):
"""
Args:
cfg: an object with the following attributes:
model: instantiate to a module
dataloader.{train,test}: instantiate to dataloaders
dataloader.evaluator: instantiate to evaluator for test set
        optimizer: instantiate to an optimizer
lr_multiplier: instantiate to a fvcore scheduler
train: other misc config defined in `configs/common/train.py`, including:
output_dir (str)
init_checkpoint (str)
amp.enabled (bool)
max_iter (int)
eval_period, log_period (int)
device (str)
checkpointer (dict)
ddp (dict)
"""
logger = logging.getLogger("ufo")
train_loader = instantiate(cfg.dataloader.train)
auto_adjust_cfg(cfg, train_loader)
logger.info(cfg)
model = instantiate(cfg.model)
logger.info("Model:\n{}".format(model))
model.to(cfg.train.device)
cfg.optimizer.model = model
optim = instantiate(cfg.optimizer)
logger.info("Optim:\n{}".format(optim))
if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model, find_unused_parameters=True)
trainer = (AMPTrainer if cfg.train.amp.enabled else SimpleTrainer)(model, train_loader, optim)
checkpointer = Checkpointer(
model,
cfg.train.output_dir,
optimizer=optim,
trainer=trainer,
)
# # set optimizer
# cfg.lr_multiplier.optimizer = optim
trainer.register_hooks(
[
hooks.IterationTimer(),
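            # Paddle optimizers keep their LRScheduler in `_learning_rate`.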
LRScheduler(optimizer=optim, scheduler=optim._learning_rate),
hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
if comm.is_main_process()
else None,
# hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model, _run)),
hooks.PeriodicWriter(
default_writers(cfg.train.output_dir, cfg.train.max_iter),
period=cfg.train.log_period,
)
if comm.is_main_process()
else None,
hooks.PeriodicWriter(
[CommonMetricSacredWriter(_run, cfg.train.max_iter)],
period=cfg.train.log_period,
)
if comm.is_main_process() and _run is not None
else None,
]
)
checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)
if args.resume and checkpointer.has_checkpoint():
# The checkpoint stores the training iteration that just finished, thus we start
# at the next iteration
start_iter = trainer.iter + 1
else:
start_iter = 0
    trainer.train(start_iter, cfg.train.max_iter)


def main(args):
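    """Set up the (optionally distributed) Paddle environment, load and apply
    the LazyConfig, then either run evaluation only (`--eval-only`) or start
    training, optionally under a sacred Experiment for metric logging.
    """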
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
        print('rank is {}, world_size is {}, gpu is {}'.format(args.rank, args.world_size, args.gpu))
paddle.set_device('gpu')
rank = paddle.distributed.get_rank()
print('rank is {}, world size is {}'.format(rank, paddle.distributed.get_world_size()))
if paddle.distributed.get_world_size() > 1:
paddle.distributed.init_parallel_env()
cfg = LazyConfig.load(args.config_file)
cfg = LazyConfig.apply_overrides(cfg, args.opts)
default_setup(cfg, args)
if args.eval_only:
train_loader = instantiate(cfg.dataloader.train)
auto_adjust_cfg(cfg, train_loader)
model = instantiate(cfg.model)
model.to(cfg.train.device)
if paddle.distributed.get_world_size() > 1:
model = paddle.DataParallel(model)
Checkpointer(model).load(cfg.train.init_checkpoint)
print(do_test(cfg, model))
else:
if cfg.train.sacred.enabled and comm.is_main_process():
from sacred import Experiment
from sacred.observers import MongoObserver
from sacred.observers import FileStorageObserver
ex = Experiment(cfg.train.output_dir)
mongo_url = None
# do not add mongo.txt into git repo
if os.path.exists('mongo.txt'):
with open('mongo.txt', 'r') as fin:
mongo_url = fin.readline().strip()
else:
                print('mongo.txt does not exist; using file observer instead')
if mongo_url is not None:
ex.observers.append(MongoObserver(mongo_url))
else:
ex.observers.append(FileStorageObserver(cfg.train.output_dir))
@ex.config
def train_cfg():
args = None
cfg_for_sacred = None
            # sacred converts `cfg` to its own `ConfigScope` type; to keep the
            # original LazyConfig object we pass sacred only a plain dict record
            # (`cfg_for_sacred`) and let `do_train` close over the real `cfg`.
def do_train_sacred(args, cfg_for_sacred, _run):
do_train(args, cfg, cfg_for_sacred, _run)
ex.add_source_file(args.config_file)
ex.main(do_train_sacred)
ex.run(config_updates={'args': args, 'cfg_for_sacred': LazyConfig.to_dict(cfg)})
else:
            do_train(args, cfg)


if __name__ == "__main__":
args = default_argument_parser().parse_args()
main(args)
# dist.spawn(main(args),gpus='0,1')
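#
# Example invocations (flags come from detectron2's default_argument_parser;
# the config path and checkpoint below are illustrative):
#
#   python ufo_train.py --config-file configs/supernet.py
#   python ufo_train.py --config-file configs/supernet.py --eval-only \
#       train.init_checkpoint=outputs/model_final.pdparams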