-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
395 lines (326 loc) · 16.5 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
import os
import time
import json
import random
import argparse
import datetime
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.utils import accuracy, AverageMeter
from models import build_model
from utils.lr_scheduler import build_scheduler
from utils.optimizer import build_optimizer
from utils.logger import create_logger
from utils.utils import NativeScalerWithGradNormCount, auto_resume_helper, reduce_tensor
from utils.utils import load_checkpoint_ema, load_pretrained_ema, save_checkpoint_ema
from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count
from timm.utils import ModelEma as ModelEma
# Report (without a trailing newline) the multiprocessing start method currently
# in effect, then force "spawn" for the rest of the process.
# NOTE(review): this sits between imports, just before the datautil import;
# presumably the placement is deliberate — keep the order unless confirmed otherwise.
_initial_start_method = torch.multiprocessing.get_start_method()
print(f"||{_initial_start_method}||", end="")
torch.multiprocessing.set_start_method("spawn", force=True)
from datautil.getdataloader import get_img_dataloader
from utils.dgutil import train_valid_target_eval_names, DG_accuracy, img_param_init
from config import get_config
def str2bool(v):
    """Interpret a command-line token as a boolean.

    Lets argparse accept flags written as ``--arg1 true --arg2 false``.
    Real ``bool`` values are returned unchanged. The strings yes/true/t/y/1 map
    to True and no/false/f/n/0 map to False, compared case-insensitively.

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    if isinstance(v, bool):
        return v
    token = v.lower()
    if token in {'yes', 'true', 't', 'y', '1'}:
        return True
    if token in {'no', 'false', 'f', 'n', '0'}:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def parse_option():
    """Parse command-line options and build the run configuration.

    Unrecognized arguments are accepted by ``parse_known_args`` and discarded.
    The resulting namespace is passed through ``img_param_init``; ``get_config``
    then builds the config from it.

    Returns:
        tuple: ``(args, config)`` — the namespace returned by ``img_param_init``
        and the config returned by ``get_config``.
    """
    parser = argparse.ArgumentParser('DGMamba training and evaluation script', add_help=False)
    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs='+',
    )

    # easy config modification
    parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
    parser.add_argument('--data-path', type=str, help='path to dataset')
    parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
    parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
                        help='no: no cache, '
                             'full: cache all data, '
                             'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
    parser.add_argument('--pretrained',
                        help='pretrained weight from checkpoint, could be imagenet22k pretrained weight')
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
    parser.add_argument('--use-checkpoint', action='store_true',
                        help="whether to use gradient checkpointing to save memory")
    parser.add_argument('--disable_amp', action='store_true', help='Disable pytorch amp')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--throughput', action='store_true', help='Test throughput only')
    parser.add_argument('--fused_layernorm', action='store_true', help='Use fused layernorm.')
    parser.add_argument('--optim', type=str, help='overwrite optimizer if provided, can be adamw/sgd.')

    # EMA related parameters
    parser.add_argument('--model_ema', type=str2bool, default=False)
    parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='')
    parser.add_argument('--model_ema_force_cpu', type=str2bool, default=False, help='')
    # Only values strictly between 0 and 1 take effect (checked in the __main__ section).
    parser.add_argument('--memory_limit_rate', type=float, default=-1, help='limitation of gpu memory use')

    # DG parameters
    parser.add_argument('--dataset', type=str, default='PACS')
    # nargs='+' yields a list of ints when the flag is given, so the default is
    # a list too: args.test_envs then has the same type whether or not the flag
    # is supplied (a bare int default breaks iteration / `in` checks).
    parser.add_argument('--test_envs', type=int, nargs='+', default=[0])
    parser.add_argument('--split_style', type=str, default='strat')
    parser.add_argument('--lr', type=float, default=None)
    parser.add_argument('--algorithm', type=str, default='DGMamba')
    parser.add_argument('--alpha', type=float, default=0.25)
    parser.add_argument('--output', type=str, default=None)

    # Unknown arguments are tolerated and ignored.
    args, _ = parser.parse_known_args()
    args = img_param_init(args)
    config = get_config(args)
    return args, config
def _build_criterion(config):
    """Return the training loss selected by the mixup / label-smoothing config."""
    if config.AUG.MIXUP > 0.:
        # smoothing is handled with mixup label transform
        return SoftTargetCrossEntropy()
    if config.MODEL.LABEL_SMOOTHING > 0.:
        return LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING)
    return torch.nn.CrossEntropyLoss()


def _apply_auto_resume(config):
    """Point config.MODEL.RESUME at the checkpoint found in config.OUTPUT, if auto-resume is on."""
    if not config.TRAIN.AUTO_RESUME:
        return
    resume_file = auto_resume_helper(config.OUTPUT)
    if resume_file:
        if config.MODEL.RESUME:
            logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
        config.defrost()
        config.MODEL.RESUME = resume_file
        config.freeze()
        logger.info(f'auto resuming from {resume_file}')
    else:
        logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')


def main(config):
    """Build the model, then evaluate, benchmark throughput, or run DG training.

    Relies on the module-level ``args`` and ``logger`` created in the
    ``__main__`` section of this script.

    Flow:
      * Optionally resume from a checkpoint; in EVAL_MODE, stop after that evaluation.
      * Optionally load pretrained weights (only when not resuming).
      * In THROUGHPUT_MODE, benchmark and stop.
      * Otherwise train for epochs [START_EPOCH, EPOCHS). After each epoch, log
        the target accuracy of the epoch with the best validation accuracy.
    """
    data_loader_train, data_loader_val, mixup_fn = get_img_dataloader(config)
    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)
    model.cuda()
    model_without_ddp = model

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume='')
        print("Using EMA with decay = %.8f" % args.model_ema_decay)

    optimizer = build_optimizer(config, model, logger)
    model = torch.nn.parallel.DistributedDataParallel(model, broadcast_buffers=False)
    loss_scaler = NativeScalerWithGradNormCount()

    # train_one_epoch runs config.steps_per_epoch iterations per epoch and calls
    # lr_scheduler.step_update((epoch * steps_per_epoch + idx) // ACCUMULATION_STEPS).
    # The scheduler must therefore be sized with the same per-epoch step count.
    # data_loader_train is a collection of loaders zipped together in
    # train_one_epoch, so len(data_loader_train) is the number of loaders,
    # not the number of iterations.
    steps_per_epoch = config.steps_per_epoch
    if config.TRAIN.ACCUMULATION_STEPS > 1:
        lr_scheduler = build_scheduler(config, optimizer, steps_per_epoch // config.TRAIN.ACCUMULATION_STEPS)
    else:
        lr_scheduler = build_scheduler(config, optimizer, steps_per_epoch)

    criterion = _build_criterion(config)

    best_valid_acc, best_target_acc = 0, 0
    best_valid_acc_ema, best_target_acc_ema = 0, 0
    best_epoch, best_epoch_ema = 0, 0

    _apply_auto_resume(config)

    if config.MODEL.RESUME:
        # The returned (max_accuracy, max_accuracy_ema) pair is not used here.
        # NOTE(review): best_* trackers restart from 0 after a resume.
        load_checkpoint_ema(config, model_without_ddp, optimizer, lr_scheduler, loss_scaler, logger, model_ema)
        _, _, target_acc = validate(config, data_loader_val, model)
        logger.info(f"Accuracy of the network on the TARGET test images: {target_acc:.1f}%")
        if model_ema is not None:
            _, _, target_acc_ema = validate(config, data_loader_val, model_ema.ema)
            logger.info(f"Accuracy of the network ema on the TARGET test images: {target_acc_ema:.1f}%")
        if config.EVAL_MODE:
            return

    if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):
        load_pretrained_ema(config, model_without_ddp, logger, model_ema)
        if model_ema is not None:
            _, _, target_acc_ema = validate(config, data_loader_val, model_ema.ema)
            logger.info(f"Accuracy of the network ema on the TARGET test images: {target_acc_ema:.1f}%")

    if config.THROUGHPUT_MODE:
        throughput(data_loader_val, model, logger)
        if model_ema is not None:
            throughput(data_loader_val, model_ema.ema, logger)
        return

    logger.info("Start training")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler, model_ema)

        _, valid_acc, target_acc = validate(config, data_loader_val, model)
        logger.info(f"Accuracy of the network on the TARGET test images: {target_acc:.1f}%")
        # Select by validation accuracy; on a tie, keep the higher target accuracy.
        # All three trackers move together, so best_epoch always names the
        # epoch that produced the reported best_target_acc.
        if valid_acc > best_valid_acc or (valid_acc == best_valid_acc and target_acc > best_target_acc):
            best_valid_acc, best_target_acc, best_epoch = valid_acc, target_acc, epoch
        logger.info(f'Epoch: {epoch} Best epoch: {best_epoch} DG target accuracy: {best_target_acc:.4f}%')

        if model_ema is not None:
            _, valid_acc_ema, target_acc_ema = validate(config, data_loader_val, model_ema.ema)
            logger.info(f"Accuracy of the network ema on the TARGET test images: {target_acc_ema:.1f}%")
            # NOTE(review): unlike the main model, a tie in EMA validation
            # accuracy keeps the earlier epoch (strict comparison).
            if valid_acc_ema > best_valid_acc_ema:
                best_valid_acc_ema = valid_acc_ema
                best_target_acc_ema = target_acc_ema
                best_epoch_ema = epoch
            logger.info(f'Epoch: {epoch} Best epoch ema: {best_epoch_ema} DG target accuracy of ema: {best_target_acc_ema:.4f}%')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler, model_ema=None):
    """Train ``model`` for ``config.steps_per_epoch`` iterations.

    Each iteration draws one minibatch from every loader in ``data_loader``
    (zipped in lockstep) and concatenates them into a single batch.
    Gradients are accumulated over ``config.TRAIN.ACCUMULATION_STEPS``
    iterations before each optimizer step.

    NOTE(review): assumes every loader yields at least ``steps_per_epoch``
    batches (e.g. infinite loaders); otherwise ``next`` raises StopIteration.

    Args:
        config: frozen run config (reads steps_per_epoch, alpha, AMP_ENABLE,
            PRINT_FREQ and TRAIN.*).
        model: the (DDP-wrapped) model, called as ``model(samples, y, epoch)``.
        criterion: loss applied to ``(outputs, all_targets)``.
        data_loader: collection of per-domain training loaders.
        optimizer: optimizer driven through ``loss_scaler``.
        epoch: current epoch index.
        mixup_fn: optional mixup transform applied to ``(samples, targets)``.
        lr_scheduler: scheduler advanced via ``step_update`` after each optimizer step.
        loss_scaler: AMP scaler that runs backward / clip / step.
        model_ema: optional EMA wrapper updated after each optimizer step.
    """
    model.train()
    optimizer.zero_grad()

    num_steps = config.steps_per_epoch
    accum_steps = config.TRAIN.ACCUMULATION_STEPS
    # this attribute is added by timm on one optimizer (adahessian); it is
    # forwarded to the scaler as create_graph.
    is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order

    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    scaler_meter = AverageMeter()

    start = time.time()
    end = time.time()
    train_minibatches_iterator = zip(*data_loader)
    for idx in range(num_steps):
        minibatches = next(train_minibatches_iterator)
        samples = torch.cat([data[0].float() for data in minibatches])
        targets = torch.cat([data[1].long() for data in minibatches])
        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)
        # The model is always given the hard labels, even if mixup replaces `targets`.
        y = targets
        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)
        # From epoch 2 on (when alpha >= 0) targets are duplicated.
        # NOTE(review): presumably the model returns outputs for a doubled batch
        # then — confirm against the model's forward().
        if config.alpha >= 0 and epoch >= 2:
            all_targets = torch.cat((targets, targets))
        else:
            all_targets = targets
        data_time.update(time.time() - end)

        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            outputs = model(samples, y, epoch)
        loss = criterion(outputs, all_targets)
        loss = loss / accum_steps

        is_update_step = (idx + 1) % accum_steps == 0
        grad_norm = loss_scaler(loss, optimizer, clip_grad=config.TRAIN.CLIP_GRAD,
                                parameters=model.parameters(), create_graph=is_second_order,
                                update_grad=is_update_step)
        if is_update_step:
            optimizer.zero_grad()
            lr_scheduler.step_update((epoch * num_steps + idx) // accum_steps)
            # Update the EMA once per optimizer step, after the weights change.
            if model_ema is not None:
                model_ema.update(model)
        loss_scale_value = loss_scaler.state_dict()["scale"]

        torch.cuda.synchronize()

        loss_meter.update(loss.item(), all_targets.size(0))
        if grad_norm is not None:  # loss_scaler return None if not update
            norm_meter.update(grad_norm)
        scaler_meter.update(loss_scale_value)
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]['lr']
            wd = optimizer.param_groups[0]['weight_decay']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t wd {wd:.4f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'data time {data_time.val:.4f} ({data_time.avg:.4f})\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'loss_scale {scaler_meter.val:.4f} ({scaler_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')
    epoch_time = time.time() - start
    logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
@torch.no_grad()
def validate(config, data_loader, model):
    """Compute mean accuracy for the train / valid / target evaluation groups.

    ``train_valid_target_eval_names(config)`` maps each group name to the
    keys/indices of the loaders in ``data_loader`` belonging to that group.
    ``DG_accuracy`` is evaluated on each of those loaders, and the results are
    averaged per group (unweighted mean over loaders).

    Returns:
        tuple: ``(train_acc, valid_acc, target_acc)``.
    """
    eval_name_dict = train_valid_target_eval_names(config)
    acc_record = {
        split: np.mean([DG_accuracy(model, data_loader[i]) for i in eval_name_dict[split]])
        for split in ('train', 'valid', 'target')
    }
    logger.info(f" * train {acc_record['train']:.4f} valid {acc_record['valid']:.4f} target {acc_record['target']:.4f}")
    return acc_record['train'], acc_record['valid'], acc_record['target']
@torch.no_grad()
def throughput(data_loader, model, logger):
    """Benchmark inference throughput (images/s) on a single batch.

    ``data_loader`` may be one loader or a list/tuple of loaders; with a
    collection, the first loader is used. ``main`` passes the same validation
    collection that ``validate`` indexes per domain. Only the first batch is
    used: 50 untimed warm-up forward passes, then 30 timed passes. Returns
    without doing anything if there is no batch to measure.

    Args:
        data_loader: a loader, or a list/tuple of loaders.
        model: model called as ``model(images)``.
        logger: logger used to report the measurement.
    """
    model.eval()
    if isinstance(data_loader, (list, tuple)):
        # Iterating a collection of loaders would yield loaders, not batches.
        if not data_loader:
            return
        loader = data_loader[0]
    else:
        loader = data_loader
    batch = next(iter(loader), None)
    if batch is None:
        return
    # Element 0 is taken as the image tensor; any remaining fields
    # (labels, domain ids, ...) are ignored.
    images = batch[0] if isinstance(batch, (list, tuple)) else batch
    images = images.cuda(non_blocking=True)
    batch_size = images.shape[0]
    for _ in range(50):
        model(images)
    torch.cuda.synchronize()
    logger.info(f"throughput averaged with 30 times")
    tic1 = time.time()
    for _ in range(30):
        model(images)
    torch.cuda.synchronize()
    tic2 = time.time()
    logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}")
if __name__ == '__main__':
    # `args` and `logger` are module globals read by main() and the helpers above.
    args, config = parse_option()

    if config.AMP_OPT_LEVEL:
        print("[warning] Apex amp has been deprecated, please use pytorch amp instead!")

    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ['WORLD_SIZE'])
        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
    else:
        rank = -1
        world_size = -1
    # RANK is the global rank and can exceed the local GPU count on multi-node
    # runs. Launchers such as torchrun also export LOCAL_RANK, which is the
    # correct device index; fall back to RANK when it is absent.
    # (set_device(-1) is a no-op.)
    local_rank = int(os.environ.get('LOCAL_RANK', rank))
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
    dist.barrier()

    # Per-rank seed so different ranks do not draw identical random streams.
    seed = config.SEED + dist.get_rank()
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = False
    cudnn.deterministic = True

    # linear scale the learning rate according to total batch size, may not be optimal
    linear_scaled_lr = config.TRAIN.BASE_LR * dist.get_world_size()
    linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * dist.get_world_size()
    linear_scaled_min_lr = config.TRAIN.MIN_LR * dist.get_world_size()
    # gradient accumulation also need to scale the learning rate
    if config.TRAIN.ACCUMULATION_STEPS > 1:
        linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS

    config.defrost()
    config.TRAIN.BASE_LR = linear_scaled_lr
    config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
    config.TRAIN.MIN_LR = linear_scaled_min_lr
    # to make sure all the config.OUTPUT are the same: rank 0's value is broadcast
    obj = [config.OUTPUT] if dist.get_rank() == 0 else [None]
    dist.broadcast_object_list(obj)
    dist.barrier()
    config.OUTPUT = obj[0]
    print(config.OUTPUT, flush=True)
    config.freeze()

    os.makedirs(config.OUTPUT, exist_ok=True)
    logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}")

    if dist.get_rank() == 0:
        # NOTE(review): the file is named config.json but holds config.dump()
        # output; the name is kept for compatibility with existing tooling.
        path = os.path.join(config.OUTPUT, "config.json")
        with open(path, "w", encoding="utf-8") as f:
            f.write(config.dump())
        logger.info(f"Full config saved to {path}")

    # print config
    logger.info(config.dump())
    logger.info(json.dumps(vars(args)))

    if 0 < args.memory_limit_rate < 1:
        # The fraction applies to the current device (set above), so report
        # that device's capacity rather than device 0's.
        torch.cuda.set_per_process_memory_fraction(args.memory_limit_rate)
        device_index = torch.cuda.current_device()
        usable_memory = torch.cuda.get_device_properties(device_index).total_memory * args.memory_limit_rate / 1e6
        print(f"===========> GPU memory is limited to {usable_memory}MB", flush=True)

    main(config)