Skip to content
This repository has been archived by the owner on Apr 17, 2023. It is now read-only.

[Detection] Output saliency map and feature vector for MPA Det #24

Merged
merged 2 commits into from
Jul 7, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 35 additions & 10 deletions mpa/det/inferrer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
from typing import List, Tuple

import torch
from mmcv.parallel import MMDataParallel, is_module_wrapper
Expand All @@ -9,6 +10,7 @@
from mmdet.datasets import build_dataloader, build_dataset, replace_ImageToTensor
from mmdet.models import build_detector
from mmdet.parallel import MMDataCPU
from mmdet.utils.deployment import get_saliency_map, get_feature_vector

from mpa.registry import STAGES
from .stage import DetectionStage
Expand All @@ -32,6 +34,8 @@ def run(self, model_cfg, model_ckpt, data_cfg, **kwargs):
self._init_logger()
mode = kwargs.get('mode', 'train')
eval = kwargs.get('eval', False)
dump_features = kwargs.get('dump_features', False)
dump_saliency_map = kwargs.get('dump_saliency_map', False)
if mode not in self.mode:
return {}

Expand All @@ -42,7 +46,8 @@ def run(self, model_cfg, model_ckpt, data_cfg, **kwargs):

# mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))

outputs = self.infer(cfg, eval=eval)
outputs = self.infer(cfg, eval=eval, dump_features=dump_features,
dump_saliency_map=dump_saliency_map)

# Save outputs
# output_file_path = osp.join(cfg.work_dir, 'infer_result.npy')
Expand All @@ -65,7 +70,7 @@ def default(self, obj):
print(json_dump)
"""

def infer(self, cfg, dump_features=False, eval=False):
def infer(self, cfg, eval=False, dump_features=False, dump_saliency_map=False):
samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
if samples_per_gpu > 1:
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
Expand Down Expand Up @@ -140,26 +145,44 @@ def infer(self, cfg, dump_features=False, eval=False):
# detections = single_gpu_test(model, data_loader)
eval_predictions = []
feature_vectors = []
saliency_maps = []

def dump_features_hook(mod, inp, out):
with torch.no_grad():
feature_map = out[-1]
feature_vector = torch.nn.functional.adaptive_avg_pool2d(feature_map, (1, 1))
feature_vector = get_feature_vector(out)
assert feature_vector.size(0) == 1
feature_vectors.append(feature_vector.view(-1).detach().cpu().numpy())

def dummy_dump_features_hook(mod, inp, out):
feature_vectors.append(None)

hook = dump_features_hook if dump_features else dummy_dump_features_hook
def dump_saliency_hook(model: torch.nn.Module, input: Tuple, out: List[torch.Tensor]):
""" Dump the last feature map to `saliency_maps` cache

Args:
model (torch.nn.Module): PyTorch model
input (Tuple): input
out (List[torch.Tensor]): a list of feature maps
"""
with torch.no_grad():
saliency_map = get_saliency_map(out[-1])
saliency_maps.append(saliency_map.squeeze(0).detach().cpu().numpy())

def dummy_dump_saliency_hook(model, input, out):
saliency_maps.append(None)

feature_vector_hook = dump_features_hook if dump_features else dummy_dump_features_hook
saliency_map_hook = dump_saliency_hook if dump_saliency_map else dummy_dump_saliency_hook

# Use a single gpu for testing. Set in both mm_val_dataloader and eval_model
if is_module_wrapper(model):
model = model.module
with model.backbone.register_forward_hook(hook):
for data in data_loader:
with torch.no_grad():
result = eval_model(return_loss=False, rescale=True, **data)
eval_predictions.extend(result)
with eval_model.module.backbone.register_forward_hook(feature_vector_hook):
with eval_model.module.backbone.register_forward_hook(saliency_map_hook):
for data in data_loader:
with torch.no_grad():
result = eval_model(return_loss=False, rescale=True, **data)
eval_predictions.extend(result)

for key in [
'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
Expand All @@ -175,5 +198,7 @@ def dummy_dump_features_hook(mod, inp, out):
classes=target_classes,
detections=eval_predictions,
metric=metric,
feature_vectors=feature_vectors,
saliency_maps=saliency_maps
)
return outputs