-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathtest.py
85 lines (71 loc) · 2.93 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""
Test a model and compute detection mAP
Fred Zhang <frederic.zhang@anu.edu.au>
The Australian National University
Australian Centre for Robotic Vision
"""
import os
import torch
import argparse
import torchvision
from torch.utils.data import DataLoader
import pocket
from hicodet.hicodet import HICODet
from models import SpatiallyConditionedGraph as SCG
from utils import DataFactory, custom_collate, test
def main(args):
    """Evaluate an SCG model on HICO-DET and print detection mAP.

    Prints the mean AP over all interaction classes, plus the mean over the
    "rare" subset (classes with fewer than 10 annotations in the
    ``instances_train2015.json`` file) and the "non-rare" subset (the rest).

    Args:
        args: Parsed command-line namespace. Fields read here:
            ``data_root``, ``partition``, ``detection_dir``, ``num_workers``,
            ``num_iter``, ``max_human``, ``max_object``, ``box_score_thresh``
            and ``model_path``.
    """
    # Run on CUDA device 0 with cuDNN autotuning disabled.
    torch.cuda.set_device(0)
    torch.backends.cudnn.benchmark = False

    # Annotation counts per interaction class, taken from the *training*
    # annotations; their indices line up with the AP vector indexed below.
    num_anno = torch.tensor(HICODet(None, anno_file=os.path.join(
        args.data_root, 'instances_train2015.json')).anno_interaction)
    rare = torch.nonzero(num_anno < 10).squeeze(1)
    non_rare = torch.nonzero(num_anno >= 10).squeeze(1)

    dataloader = DataLoader(
        dataset=DataFactory(
            name='hicodet', partition=args.partition,
            data_root=args.data_root,
            detection_root=args.detection_dir,
        ), collate_fn=custom_collate, batch_size=1,
        num_workers=args.num_workers, pin_memory=True
    )

    # NOTE(review): 49 is presumably the index of the human ("person") object
    # class in HICO-DET -- confirm against SpatiallyConditionedGraph.
    net = SCG(
        dataloader.dataset.dataset.object_to_verb, 49,
        num_iterations=args.num_iter,
        max_human=args.max_human,
        max_object=args.max_object,
        box_score_thresh=args.box_score_thresh
    )

    epoch = 0
    if os.path.exists(args.model_path):
        print("Loading model from ", args.model_path)
        # torch.load unpickles the file: only load checkpoints from trusted
        # sources.
        checkpoint = torch.load(args.model_path, map_location="cpu")
        net.load_state_dict(checkpoint['model_state_dict'])
        # Tolerate checkpoints saved without an epoch counter rather than
        # crashing after the weights have already been loaded.
        epoch = checkpoint.get("epoch", 0)
    elif len(args.model_path):
        # A path was given but does not exist: deliberately fall back to the
        # freshly initialised network instead of aborting.
        print("\nWARNING: The given model path does not exist. "
              "Proceed to use a randomly initialised model.\n")
    net.cuda()

    # HandyTimer with maxlen=1 keeps only the duration of this single run.
    timer = pocket.utils.HandyTimer(maxlen=1)
    with timer:
        test_ap = test(net, dataloader)
    print("Model at epoch: {} | time elapsed: {:.2f}s\n"
          "Full: {:.4f}, rare: {:.4f}, non-rare: {:.4f}".format(
              epoch, timer[0], test_ap.mean(),
              test_ap[rare].mean(), test_ap[non_rare].mean()
          ))
if __name__ == "__main__":
    # Command-line entry point: parse evaluation options and run main().
    parser = argparse.ArgumentParser(description="Test an interaction head")
    parser.add_argument('--data-root', default='hicodet', type=str,
                        help="HICO-DET root directory (must contain "
                             "instances_train2015.json)")
    parser.add_argument('--detection-dir', default='hicodet/detections/test2015',
                        type=str, help="Directory where detection files are stored")
    parser.add_argument('--partition', default='test2015', type=str,
                        help="Dataset partition to evaluate on")
    parser.add_argument('--num-iter', default=2, type=int,
                        help="Number of iterations to run message passing")
    parser.add_argument('--box-score-thresh', default=0.2, type=float,
                        help="Box score threshold passed to the model")
    parser.add_argument('--max-human', default=15, type=int,
                        help="Maximum number of human boxes passed to the model")
    parser.add_argument('--max-object', default=15, type=int,
                        help="Maximum number of object boxes passed to the model")
    parser.add_argument('--num-workers', default=2, type=int,
                        help="Number of data-loading worker processes")
    parser.add_argument('--model-path', default='', type=str,
                        help="Checkpoint to evaluate; leave empty to use a "
                             "randomly initialised model")
    args = parser.parse_args()
    print(args)
    main(args)