Skip to content

Commit

Permalink
feat: add per class mAP and mAR to tensorboard
Browse files Browse the repository at this point in the history
  • Loading branch information
harmluSICKAG committed Jan 13, 2025
1 parent 0712f59 commit e4d4239
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions src/solver/det_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,24 @@
from .det_engine import train_one_epoch, evaluate


def prepare_eval_metric(metric, catId, type):
# type precision: (iou, recall, cls, area range, max dets)
# type recall: (iou, cls, area range, max dets)
if type == "precision":
metric = [metric[i][j][catId][0][-1] for i in range(len(metric)) for j in range(len(metric[i]))]
elif type == "recall":
metric = [metric[i][catId][0][-1] for i in range(len(metric))]

# Filter out values <= -1
filtered_metric = [value for value in metric if value > -1]

# Calculate mean or return NaN if empty
if filtered_metric:
return sum(filtered_metric) / len(filtered_metric)
else:
return float("nan")


class DetSolver(BaseSolver):

def fit(self, ):
Expand Down Expand Up @@ -98,11 +116,27 @@ def fit(self, ):
self.device
)

coco_eval = coco_evaluator.coco_eval["bbox"]
precisions = coco_evaluator.coco_eval["bbox"].eval['precision']
recalls = coco_evaluator.coco_eval["bbox"].eval['recall']

class_results = {}
for category_id in coco_eval.cocoGt.getCatIds():
category_info = coco_eval.cocoGt.loadCats([category_id])[0]
category_name = category_info['name']
ap = prepare_eval_metric(precisions, category_id, "precision")
ar = prepare_eval_metric(recalls, category_id, "recall")
class_results[category_id] = {'name': category_name, 'mAP': ap, 'mAR': ar}

# TODO
for k in test_stats:
if self.writer and dist_utils.is_main_process():
for i, v in enumerate(test_stats[k]):
self.writer.add_scalar(f'Test/{k}_{i}'.format(k), v, epoch)

for class_id in class_results.keys():
self.writer.add_scalar(f'Test/class_{class_results[class_id]["name"]}_mAP', class_results[class_id]["mAP"], epoch)
self.writer.add_scalar(f'Test/class_{class_results[class_id]["name"]}_mAR', class_results[class_id]["mAR"], epoch)

if k in best_stat:
best_stat['epoch'] = epoch if test_stats[k][0] > best_stat[k] else best_stat['epoch']
Expand Down

0 comments on commit e4d4239

Please sign in to comment.