-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtest_model.py
57 lines (42 loc) · 1.86 KB
/
test_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os
import torch
from tqdm import tqdm
from config import opt
from object_place_dataset import get_test_dataloader
from object_place_net import ObjectPlaceNet
def F1(preds, gts):
    """Compute the F1 score and balanced accuracy for binary 0/1 labels.

    Label 1 is treated as the positive class and 0 as the negative class.
    Pairs where either value is neither 0 nor 1 are not counted in any cell.
    Iteration stops at the shorter of the two inputs.

    Args:
        preds: iterable of predicted labels.
        gts: iterable of ground-truth labels, paired element-wise with preds.

    Returns:
        Tuple ``(f1, bal_acc)``:
          * ``f1`` is ``2*tp / (2*tp + fp + fn)``.
          * ``bal_acc`` is the mean of the true-positive rate and the
            true-negative rate.
        Any ratio whose denominator is zero is reported as 0.
    """
    tp = fp = fn = tn = 0
    # Single pass over the paired labels instead of four separate scans.
    for pred, gt in zip(preds, gts):
        if pred == 1 and gt == 1:
            tp += 1
        elif pred == 1 and gt == 0:
            fp += 1
        elif pred == 0 and gt == 1:
            fn += 1
        elif pred == 0 and gt == 0:
            tn += 1

    tpr = tp / (tp + fn) if (tp + fn) > 0 else 0
    tnr = tn / (tn + fp) if (tn + fp) > 0 else 0
    f1 = (2 * tp) / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else 0
    bal_acc = (tpr + tnr) / 2
    return f1, bal_acc
def evaluate_model(device, checkpoint_path='./best-acc.pth'):
    """Evaluate an ObjectPlaceNet checkpoint on the test dataloader.

    Loads the state dict at ``checkpoint_path`` into a freshly built
    ``ObjectPlaceNet``. Runs the model over every batch from
    ``get_test_dataloader()`` and scores the arg-max class predictions
    against the labels with ``F1``.

    Args:
        device: torch device (or device string such as ``"cuda:0"``) used
            both as ``map_location`` for the weights and for inference.
        checkpoint_path: path to a saved ``state_dict`` for ObjectPlaceNet.

    Returns:
        Tuple ``(f1, balanced_accuracy)`` over all evaluated samples.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.

    Side effects:
        Sets ``opt.without_mask = False`` on the shared config object.
        Prints progress and the final scores.
    """
    # NOTE(review): mutates the shared config, presumably so the test data is
    # built with the mask input -- confirm against config/dataset code.
    opt.without_mask = False
    if not os.path.exists(checkpoint_path):
        # Explicit raise rather than `assert`, which is stripped under `python -O`.
        raise FileNotFoundError(checkpoint_path)

    net = ObjectPlaceNet(backbone_pretrained=False)
    print('load pretrained weights from ', checkpoint_path)
    net.load_state_dict(torch.load(checkpoint_path, map_location=device))
    net = net.to(device).eval()

    pred_labels = []
    gts = []
    test_loader = get_test_dataloader()
    with torch.no_grad():
        for img_cat, label, _target_box in tqdm(test_loader):
            # Only the model input needs to live on `device`. Labels are
            # compared on the CPU, and the target box is not used for scoring.
            logits = net(img_cat.to(device))
            pred_labels.extend(logits.argmax(dim=1).cpu().numpy())
            gts.extend(label.cpu().numpy())

    total = len(gts)
    total_f1, total_bal_acc = F1(pred_labels, gts)
    print("Baseline model evaluate on {} images, local:f1={:.4f},bal_acc={:.4f}".format(
        total, total_f1, total_bal_acc))
    return total_f1, total_bal_acc
if __name__ == '__main__':
    # Fall back to CPU so the evaluation script also runs on machines without CUDA.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    f1, balanced_acc = evaluate_model(device, checkpoint_path='./best-acc.pth')