Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add optional score threshold option to coco_error_analysis.py #11117

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 45 additions & 12 deletions tools/analysis_tools/coco_error_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,12 @@ def analyze_individual_category(k,
cocoEval.params.iouThrs = [0.1]
cocoEval.params.useCats = 1
if areas:
cocoEval.params.areaRng = [[0**2, areas[2]], [0**2, areas[0]],
[areas[0], areas[1]], [areas[1], areas[2]]]
cocoEval.params.areaRng = [
[0**2, areas[2]],
[0**2, areas[0]],
[areas[0], areas[1]],
[areas[1], areas[2]],
]
cocoEval.evaluate()
cocoEval.accumulate()
ps_supercategory = cocoEval.eval['precision'][0, :, k, :, :]
Expand All @@ -223,8 +227,12 @@ def analyze_individual_category(k,
cocoEval.params.iouThrs = [0.1]
cocoEval.params.useCats = 1
if areas:
cocoEval.params.areaRng = [[0**2, areas[2]], [0**2, areas[0]],
[areas[0], areas[1]], [areas[1], areas[2]]]
cocoEval.params.areaRng = [
[0**2, areas[2]],
[0**2, areas[0]],
[areas[0], areas[1]],
[areas[1], areas[2]],
]
cocoEval.evaluate()
cocoEval.accumulate()
ps_allcategory = cocoEval.eval['precision'][0, :, k, :, :]
Expand All @@ -237,13 +245,17 @@ def analyze_results(res_file,
res_types,
out_dir,
extraplots=None,
areas=None):
areas=None,
score_thr=None):
for res_type in res_types:
assert res_type in ['bbox', 'segm']
if areas:
assert len(areas) == 3, '3 integers should be specified as areas, \
assert (len(areas) == 3), '3 integers should be specified as areas, \
representing 3 area regions'

if score_thr:
assert score_thr >= 0, 'score_thr should be bigger than 0'

directory = os.path.dirname(out_dir + '/')
if not os.path.exists(directory):
print(f'-------------create {out_dir}-----------------')
Expand All @@ -252,6 +264,13 @@ def analyze_results(res_file,
cocoGt = COCO(ann_file)
cocoDt = cocoGt.loadRes(res_file)
imgIds = cocoGt.getImgIds()

if score_thr:
cocoDt.dataset['annotations'] = list(
filter(lambda ann: ann['score'] >= score_thr,
cocoDt.dataset['annotations']))
cocoDt.createIndex()

for res_type in res_types:
res_out_dir = out_dir + '/' + res_type + '/'
res_directory = os.path.dirname(res_out_dir)
Expand All @@ -265,9 +284,12 @@ def analyze_results(res_file,
cocoEval.params.iouThrs = [0.75, 0.5, 0.1]
cocoEval.params.maxDets = [100]
if areas:
cocoEval.params.areaRng = [[0**2, areas[2]], [0**2, areas[0]],
[areas[0], areas[1]],
[areas[1], areas[2]]]
cocoEval.params.areaRng = [
[0**2, areas[2]],
[0**2, areas[0]],
[areas[0], areas[1]],
[areas[1], areas[2]],
]
cocoEval.evaluate()
cocoEval.accumulate()
ps = cocoEval.eval['precision']
Expand Down Expand Up @@ -312,27 +334,38 @@ def main():
parser.add_argument(
'--ann',
default='data/coco/annotations/instances_val2017.json',
help='annotation file path')
help='annotation file path',
)
parser.add_argument(
'--types', type=str, nargs='+', default=['bbox'], help='result types')
parser.add_argument(
'--extraplots',
action='store_true',
help='export extra bar/stat plots')
parser.add_argument(
'--score-thr',
type=float,
default=None,
help='score threshold to filter detection bboxes, only applied'
'when users want to change it.',
)
parser.add_argument(
'--areas',
type=int,
nargs='+',
default=[1024, 9216, 10000000000],
help='area regions')
help='area regions',
)
args = parser.parse_args()
analyze_results(
args.result,
args.ann,
args.types,
out_dir=args.out_dir,
extraplots=args.extraplots,
areas=args.areas)
areas=args.areas,
score_thr=args.score_thr,
)


if __name__ == '__main__':
Expand Down