# comput_score.py
# Scores a JSON file of VQA predictions against cached annotations and prints
# the overall, per-answer-type and upper-bound accuracies.
from os import path as osp
import json
import os
import torch
import argparse
def parse_args():
    """Parse the command-line options of the scoring script.

    Options (all strings):
        --input:    path of the JSON predictions file to score.
        --name:     split name; the annotation file read is
                    ``<name>_target_count.pth``.
        --dataroot: directory that contains the annotation file.

    Returns:
        argparse.Namespace with ``input``, ``name`` and ``dataroot``.
    """
    cli = argparse.ArgumentParser()
    # (flag, default) pairs; every option is a plain string.
    string_options = (
        ('--input', 'saved_models/test_epochs_17.json'),
        ('--name', 'test'),
        ('--dataroot', '../../SSL-VQA/data/vqacp2/cache'),
    )
    for flag, fallback in string_options:
        cli.add_argument(flag, type=str, default=fallback)
    return cli.parse_args()
# Answer types that get their own accuracy line in the report.
ANSWER_TYPES = ('yes/no', 'number', 'other')


def _soft_score(votes):
    """Return the soft accuracy ``min(1, votes / 3)`` for an answer with `votes` votes."""
    return min(1, votes / 3)


def to_percent(total, n):
    """Return ``100 * total / n`` rounded to 2 decimals, or 0.0 when `n` is 0.

    The zero guard keeps the report from crashing when an answer type has no
    scored question (or when there are no annotations at all).
    """
    if not n:
        return 0.0
    return round(100 * total / n, 2)


def compute_scores(predictions, annotations):
    """Accumulate accuracy and upper-bound totals for `predictions`.

    Each prediction is matched to its annotation by ``question_id`` (not by
    position), so a missing or extra prediction cannot shift the pairing of
    all the following questions. Predictions without an annotation are
    ignored.

    Args:
        predictions: iterable of dicts with keys ``question_id`` and
            ``answer``.
        annotations: iterable of dicts with keys ``question_id``,
            ``answer_type``, ``answers_word`` (container of accepted answers)
            and ``answer_count`` (mapping answer -> number of votes).

    Returns:
        dict with
            ``count``: number of predictions found in ``answers_word``,
            ``score``: summed soft accuracy of the predictions,
            ``upper_bound``: summed soft accuracy of the most-voted answers,
            ``per_type``: ``{answer_type: {'score', 'upper_bound', 'count'}}``
                for every type in ``ANSWER_TYPES``; ``count`` is the number of
                scored questions of that type.
    """
    by_question = {anno['question_id']: anno for anno in annotations}
    per_type = {
        kind: {'score': 0, 'upper_bound': 0, 'count': 0}
        for kind in ANSWER_TYPES
    }
    stats = {'count': 0, 'score': 0, 'upper_bound': 0, 'per_type': per_type}

    for pred in predictions:
        anno = by_question.get(pred['question_id'])
        if anno is None:
            continue

        votes = anno['answer_count']
        # Best score any answer could reach for this question; default=0
        # covers an annotation with no recorded answers.
        best = _soft_score(max(votes.values(), default=0))
        stats['upper_bound'] += best

        if pred['answer'] in anno['answers_word']:
            gained = _soft_score(votes[pred['answer']])
            stats['count'] += 1
        else:
            gained = 0
        stats['score'] += gained

        bucket = per_type.get(anno['answer_type'])
        if bucket is not None:
            bucket['score'] += gained
            bucket['upper_bound'] += best
            bucket['count'] += 1

    return stats


def main():
    """Load annotations and predictions, then print the accuracy report."""
    args = parse_args()
    anno_path = osp.join(args.dataroot, '%s_target_count.pth' % (args.name))
    # NOTE(review): torch.load unpickles the file; only point --dataroot at
    # annotation caches you trust.
    annotations = torch.load(anno_path)
    annotations = sorted(annotations, key=lambda x: x['question_id'])
    if annotations:
        print(annotations[0])
    print(len(annotations))

    # Context manager so the predictions file is closed after parsing.
    with open(args.input) as pred_file:
        predictions = sorted(json.load(pred_file), key=lambda x: x['question_id'])

    stats = compute_scores(predictions, annotations)
    total = len(annotations)
    yes_no = stats['per_type']['yes/no']
    number = stats['per_type']['number']
    other = stats['per_type']['other']

    print('count:', stats['count'], ' score:', to_percent(stats['score'], total))
    print('Yes/No:', to_percent(yes_no['score'], yes_no['count']),
          'Num:', to_percent(number['score'], number['count']),
          'other:', to_percent(other['score'], other['count']))
    # Upper bound is a percentage of all annotations, rounded to 2 decimals.
    print('count:', total, ' upper_bound:', to_percent(stats['upper_bound'], total))
    print('upper_bound_Yes/No:', to_percent(yes_no['upper_bound'], yes_no['count']),
          'upper_bound_Num:', to_percent(number['upper_bound'], number['count']),
          'upper_bound_other:', to_percent(other['upper_bound'], other['count']))


if __name__ == '__main__':
    main()