# validate.py
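"""Evaluate a distilled (synthetic) dataset: train a network on the synthetic
images and validate it on the real test set.

The best checkpoint and a log file are written under the pretrained-model root
for the given dataset/model/method/IPC combination.
"""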
import argparse
import os
import time
import datetime
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from torch import nn
from utils_clom.Dataloader import get_dataset
from utils_clom.model_pool import get_network
from utils_clom.trainer import train, validate
from utils_clom.utils import set_random_seed, set_gpu, get_logger, get_pretrained_model_root, ParamDiffAug, load_syn_data
def main(args):
    args.random_seed = int(time.time() * 1000) % 100000 if args.random_seed is None else args.random_seed
    set_random_seed(args.random_seed)

    if args.data_path is None:
        args.data_path = os.path.join('data', args.dataset)
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path, exist_ok=True)
    save_root = get_pretrained_model_root(args.dataset, args.model, args.method, args.ipc)
    if not os.path.isdir(save_root):
        os.makedirs(save_root, exist_ok=True)
    print("data path:", args.data_path)

    # pick the first unused index so existing checkpoints are not overwritten
    index = 0
    while os.path.exists(os.path.join(save_root, "{}_{}_{}_IPC{}_{}.pt".format(args.dataset, args.model, args.method, args.ipc, index))):
        index += 1
    save_name = "{}_{}_{}_IPC{}_{}.pt".format(args.dataset, args.model, args.method, args.ipc, index)
    logger_name = "{}_{}_{}_IPC{}_{}_log.log".format(args.dataset, args.model, args.method, args.ipc, index)

    # log run configuration
    logger, file_handler, stream_handler = get_logger(os.path.join(save_root, logger_name))
    logger.info("Time: " + datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
    logger.info("Dataset: " + args.dataset)
    logger.info("method: " + args.method)
    logger.info("ipc: " + str(args.ipc))
    logger.info("model: " + args.model)
    logger.info("random seed: " + str(args.random_seed))
    logger.info("epochs: " + str(args.epochs))
    logger.info("batch size: " + str(args.batch_size))
    logger.info("lr: " + str(args.lr))
    logger.info("momentum: " + str(args.momentum))
    logger.info("weight decay: " + str(args.weight_decay))
    logger.info("lr decay step: " + args.lr_decay_step)
    logger.info("normalize data: " + str(args.normalize_data))
    logger.info("dsa: " + str(args.dsa))
    logger.info("dsa strategy: " + args.dsa_strategy)

    # load the original dataset (only its test split is used for evaluation)
    args.dsa_param = ParamDiffAug()
    channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)

    # load the synthetic (distilled) dataset used for training
    if args.synthesis_data_path is None:
        raise ValueError("--synthesis_data_path must be provided")
    syn_dst, _, _ = load_syn_data(args.synthesis_data_path)
    trainloader = torch.utils.data.DataLoader(syn_dst, batch_size=args.batch_size, shuffle=True, num_workers=0)

    # build the model
    model = get_network(args.model, channel, num_classes, im_size)
    model = set_gpu(args, model)

    # training settings
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    # MultiStepLR expects integer epoch milestones, so parse the comma-separated string
    milestones = [int(step) for step in args.lr_decay_step.split(',')]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
    best_acc1 = 0.0

    # start training
    logger.info("start training")
    for epoch in range(args.epochs):
        train_acc1, _, loss = train(trainloader, model, criterion, optimizer, epoch, args, aug=args.dsa)
        scheduler.step()
        if (epoch + 1) % args.save_every == 0:
            acc1, _ = validate(testloader, model, criterion, args)
            logger.info(f"epoch:{epoch}, train acc:{train_acc1} loss: {loss} test acc:{acc1}")
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)
            if args.save_model and is_best:
                torch.save(model.state_dict(), os.path.join(save_root, save_name))

    logger.info(f"dataset: {args.dataset}, method: {args.method}")
    logger.info(f"best acc: {best_acc1}")
    if args.save_model:
        logger.info(f"save path: {os.path.join(save_root, save_name)}")
    logger.removeHandler(file_handler)
    logger.removeHandler(stream_handler)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="PyTorch Testing", epilog="End of Parameters")
    parser.add_argument("--dataset", help="name of dataset", type=str, default='CIFAR10')
    parser.add_argument("--model", metavar="ARCH", default='ConvNet', help="model architecture")
    parser.add_argument('--method', type=str, default='DC', help='distillation method: DC/DSA/DM')
    parser.add_argument('--ipc', type=int, default=1, help='image(s) per class')
    parser.add_argument("--random_seed", default=None, type=int, help="random seed")
    parser.add_argument('--save_model', action="store_true", default=False, help='whether to save the best model')
    parser.add_argument("--multigpu", default=None, type=lambda x: [int(a) for a in x.split(",")],
                        help="which GPUs to use for multi-GPU training")
    parser.add_argument("--gpu", default=0, type=int, help="which GPU to use for training")
    parser.add_argument("--epochs", default=1000, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument("--batch_size", default=256, type=int, help="batch size")
    parser.add_argument("--lr", default=0.01, type=float, help="learning rate")
    parser.add_argument('--momentum', type=float, default=0.9, help="SGD momentum (default: 0.9)")
    parser.add_argument('--weight_decay', type=float, default=0.0005, help="SGD weight decay")
    parser.add_argument("--save_every", default=50, type=int, help="evaluate (and possibly save) every N epochs")
    parser.add_argument('--lr_decay_step', default='500', type=str,
                        help='comma-separated epochs at which to decay the learning rate')
    parser.add_argument('--normalize_data', action="store_true", default=False,
                        help='whether to normalize the dataset')
    parser.add_argument('--dsa', action="store_true", default=False,
                        help='whether to use differentiable Siamese augmentation')
    parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate',
                        help='differentiable Siamese augmentation strategy')
    parser.add_argument('--synthesis_data_path', type=str, default=None,
                        help="path to the synthetic (distilled) dataset")
    parser.add_argument('--data_path', type=str, default=None, help='dataset path')
    args = parser.parse_args()
    main(args)
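
# Example invocation (the synthetic-data path below is illustrative; point it
# at a file produced by your DC/DSA/DM distillation run):
#   python validate.py --dataset CIFAR10 --model ConvNet --method DC --ipc 1 \
#       --synthesis_data_path data/syn/CIFAR10_DC_IPC1.pt --dsa --save_model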