-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtrain.py
105 lines (78 loc) · 3.49 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from datetime import datetime
import argparse
import os
import numpy as np
import cv2
import torch
from torch import nn, optim
from torch.utils import data as torch_data
from tensorboardX import SummaryWriter
from coolname import generate_slug
from spair.models import SPAIR
from spair import config as cfg
from spair.dataloader import SimpleScatteredMNISTDataset
from spair import debug_tools
from spair import metric
# Every run logs to its own directory, logs_v2/<Mon-DD>-<random two-word slug>,
# and the TensorBoard writer for it is created as soon as the module is loaded.
dt = '-'.join([datetime.today().strftime('%b-%d'), generate_slug(2)])
run_log_path = 'logs_v2/%s' % (dt,)
writer = SummaryWriter(run_log_path)
print('log path:', run_log_path)
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='Enable GPU use', action='store_true')
args = parser.parse_args()

# Select the compute device. --gpu is a request rather than a requirement:
# without CUDA we still fall back to the CPU, but report it instead of
# silently ignoring the flag.
if args.gpu and torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    if args.gpu:
        print('WARNING: --gpu was given but CUDA is not available; using CPU')
    DEVICE = torch.device("cpu")
def _log_detection_metrics(z_where, z_pres, y_bbox, y_digit_count, iteration):
    """Compute and log the detection metrics for one batch.

    Prints the bounding-box mean average precision and writes it, together
    with the object-count accuracy, to the module-level ``writer`` at step
    ``iteration``.
    """
    mean_ap = metric.mAP(z_where, z_pres, y_bbox, y_digit_count)
    print('Bbox Average Precision:', mean_ap.item())
    writer.add_scalar('accuracy/bbox_average_precision', mean_ap, iteration)
    count_accuracy = metric.object_count_accuracy(z_pres, y_digit_count)
    writer.add_scalar('accuracy/object_count_accuracy', count_accuracy, iteration)


def _save_checkpoint(model, iteration):
    """Save ``model.state_dict()`` to ``<run_log_path>/checkpoints/step_<iteration>.pkl``."""
    cp_dir = os.path.join(run_log_path, 'checkpoints')
    os.makedirs(cp_dir, exist_ok=True)
    save_path = os.path.join(cp_dir, 'step_%d.pkl' % iteration)
    torch.save(model.state_dict(), save_path)


def train():
    """Train a SPAIR model on the scattered-MNIST dataset.

    Loads ``spair/data/scattered_mnist_128x128_obj14x14.hdf5`` and optimises
    the network with Adam (lr=1e-4) for up to 100000 epochs. Every step logs
    an input/reconstruction image. Detection metrics are logged every 5 steps
    after step 1000, and a ``state_dict`` checkpoint is written every 1000
    steps. Relies on the module-level ``DEVICE``, ``writer`` and
    ``run_log_path``.
    """
    image_shape = cfg.INPUT_IMAGE_SHAPE

    # Test image setup
    data = SimpleScatteredMNISTDataset('spair/data/scattered_mnist_128x128_obj14x14.hdf5')

    # Fixed seed, set before the network is constructed, for reproducible runs.
    torch.manual_seed(3)

    spair_net = SPAIR(image_shape, writer, DEVICE).to(DEVICE)
    spair_optim = optim.Adam(spair_net.parameters(), lr=1e-4)

    # The loader's configuration never changes between epochs, so build it once.
    # Re-iterating it each epoch yields a fresh pass (shuffle is not enabled).
    # Pinned host memory only speeds up host-to-CUDA copies, so skip it on CPU.
    dataloader = torch_data.DataLoader(
        data,
        batch_size=cfg.BATCH_SIZE,
        pin_memory=DEVICE.type == 'cuda',
        num_workers=1,
        drop_last=True,
    )
    batches_per_epoch = len(dataloader)

    for epoch in range(100000):
        for batch_idx, (x_image, y_bbox, y_digit_count) in enumerate(dataloader):
            # Global step counter shared by logging, metrics and checkpointing.
            iteration = epoch * batches_per_epoch + batch_idx
            x_image = x_image.to(DEVICE)
            y_bbox = y_bbox.to(DEVICE)
            y_digit_count = y_digit_count.to(DEVICE)
            print('Iteration', iteration)

            spair_optim.zero_grad()
            loss, out_img, z_where, z_pres = spair_net(x_image, iteration)
            # NOTE(review): retain_graph=True keeps this step's autograd buffers
            # alive until the next forward pass rebinds loss/out_img/z_*, which
            # raises peak memory. Kept in case SPAIR back-propagates through
            # tensors it caches across steps; drop it if that is not the case.
            loss.backward(retain_graph=True)
            spair_optim.step()

            # Log the first sample's input beside its reconstruction, joined
            # along dim 2 (the width axis if images are (C, H, W); confirm).
            combined_image = torch.cat([x_image[0], out_img[0]], dim=2)
            writer.add_image('SPAIR input_output', combined_image, iteration)

            # Detection metrics every 5 steps once past step 1000, presumably
            # after the model's training-wheel phase. Confirm the threshold
            # against the SPAIR config.
            if iteration > 1000 and iteration % 5 == 0:
                _log_detection_metrics(z_where, z_pres, y_bbox, y_digit_count, iteration)

            if iteration >= 1000 and iteration % 1000 == 0:
                _save_checkpoint(spair_net, iteration)

            print('=================\n\n')
            torch.cuda.empty_cache()

    # Debug aid: list the network's top-level submodules.
    # NOTE(review): runs once after training; confirm it was not meant per step.
    for name, _ in spair_net.named_children():
        print(name)
if __name__ == '__main__':
    # Start training only when executed as a script. The module-level setup
    # (log writer creation, argument parsing, device selection) runs on import.
    train()