evaluate.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Jun 9 15:45:16 2019

@author: viswanatha
"""
import torch

from utils import *
from datasets import PascalVOCDataset
from tqdm import tqdm
from pprint import PrettyPrinter

# Good formatting when printing the APs for each class and mAP
pp = PrettyPrinter()
# Parameters
data_folder = "dataset"
# Difficult ground-truth objects must always be considered in the mAP
# calculation, because these objects DO exist!
keep_difficult = True
batch_size = 64
workers = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = "./BEST_checkpoint_ssd300.pth.tar"
# Load the model checkpoint that is to be evaluated
# (map_location lets the checkpoint load even on a CPU-only machine)
checkpoint = torch.load(checkpoint, map_location=device)
model = checkpoint["model"]
model = model.to(device)

# Switch to eval mode
model.eval()
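# NOTE: this script assumes the checkpoint stores the entire model object
# under the "model" key. If your checkpoint stores weights only, you would
# instead build the network and load its state_dict; a hypothetical sketch
# (class and key names are illustrative, not this repo's API):
#
#     model = SSD300(n_classes=n_classes)          # hypothetical constructor
#     model.load_state_dict(checkpoint["state_dict"])
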
# Load test data
test_dataset = PascalVOCDataset(
    data_folder, split="test", keep_difficult=keep_difficult
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=test_dataset.collate_fn,
    num_workers=workers,
    pin_memory=True,
)
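# The custom collate_fn is needed because each image can contain a different
# number of objects: images are stacked into a single (N, 3, 300, 300) tensor,
# while boxes, labels, and difficulties come back as lists of per-image tensors.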


def evaluate(test_loader, model):
    """
    Evaluate.

    :param test_loader: DataLoader for test data
    :param model: model
    """
    # Make sure it's in eval mode
    model.eval()

    # Lists to store detected and true boxes, labels, scores
    det_boxes = []
    det_labels = []
    det_scores = []
    true_boxes = []
    true_labels = []
    true_difficulties = []
    # It is necessary to know which objects are 'difficult' -- see 'calculate_mAP' in utils.py

    with torch.no_grad():
        # Batches
        for i, (images, boxes, labels, difficulties) in enumerate(
            tqdm(test_loader, desc="Evaluating")
        ):
            images = images.to(device)  # (N, 3, 300, 300)

            # Forward prop.
            predicted_locs, predicted_scores = model(images)

            # Detect objects in SSD output.
            # Evaluation MUST be at min_score=0.01, max_overlap=0.45, top_k=200
            # for fair comparison with the paper's results and other repos.
            det_boxes_batch, det_labels_batch, det_scores_batch = detect_objects(
                predicted_locs,
                predicted_scores,
                min_score=0.01,
                max_overlap=0.45,
                top_k=200,
            )

            # Store this batch's results for mAP calculation
            boxes = [b.to(device) for b in boxes]
            labels = [l.to(device) for l in labels]
            difficulties = [d.to(device) for d in difficulties]

            det_boxes.extend(det_boxes_batch)
            det_labels.extend(det_labels_batch)
            det_scores.extend(det_scores_batch)
            true_boxes.extend(boxes)
            true_labels.extend(labels)
            true_difficulties.extend(difficulties)

    # Calculate mAP
    APs, mAP = calculate_mAP(
        det_boxes,
        det_labels,
        det_scores,
        true_boxes,
        true_labels,
        true_difficulties,
    )

    # Print AP for each class, then the mean over all classes
    pp.pprint(APs)
    print("\nMean Average Precision (mAP): %.3f" % mAP)


if __name__ == "__main__":
    evaluate(test_loader, model)
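# Usage (from the repo root, after training has produced the checkpoint above):
#
#     python evaluate.py
#
# PascalVOCDataset is expected to find the test split under `data_folder`
# ("dataset"); adjust `data_folder` and `checkpoint` above for your setup.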