diff --git a/mpa/det/inferrer.py b/mpa/det/inferrer.py
index 7d15f502..df512a59 100644
--- a/mpa/det/inferrer.py
+++ b/mpa/det/inferrer.py
@@ -1,6 +1,7 @@
 # Copyright (C) 2022 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 #
+from typing import List, Tuple
 
 import torch
 from mmcv.parallel import MMDataParallel, is_module_wrapper
@@ -9,6 +10,7 @@
 from mmdet.datasets import build_dataloader, build_dataset, replace_ImageToTensor
 from mmdet.models import build_detector
 from mmdet.parallel import MMDataCPU
+from mmdet.utils.deployment import get_saliency_map, get_feature_vector
 
 from mpa.registry import STAGES
 from .stage import DetectionStage
@@ -32,6 +34,8 @@ def run(self, model_cfg, model_ckpt, data_cfg, **kwargs):
         self._init_logger()
         mode = kwargs.get('mode', 'train')
         eval = kwargs.get('eval', False)
+        dump_features = kwargs.get('dump_features', False)
+        dump_saliency_map = kwargs.get('dump_saliency_map', False)
 
         if mode not in self.mode:
             return {}
@@ -42,7 +46,8 @@ def run(self, model_cfg, model_ckpt, data_cfg, **kwargs):
 
         # mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
 
-        outputs = self.infer(cfg, eval=eval)
+        outputs = self.infer(cfg, eval=eval, dump_features=dump_features,
+                             dump_saliency_map=dump_saliency_map)
 
         # Save outputs
         # output_file_path = osp.join(cfg.work_dir, 'infer_result.npy')
@@ -65,7 +70,7 @@ def default(self, obj):
                 print(json_dump)
         """
 
-    def infer(self, cfg, dump_features=False, eval=False):
+    def infer(self, cfg, eval=False, dump_features=False, dump_saliency_map=False):
         samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
         if samples_per_gpu > 1:
             # Replace 'ImageToTensor' to 'DefaultFormatBundle'
@@ -140,26 +145,44 @@ def infer(self, cfg, dump_features=False, eval=False):
         # detections = single_gpu_test(model, data_loader)
         eval_predictions = []
         feature_vectors = []
+        saliency_maps = []
 
         def dump_features_hook(mod, inp, out):
             with torch.no_grad():
-                feature_map = out[-1]
-                feature_vector = torch.nn.functional.adaptive_avg_pool2d(feature_map, (1, 1))
+                feature_vector = get_feature_vector(out)
                 assert feature_vector.size(0) == 1
                 feature_vectors.append(feature_vector.view(-1).detach().cpu().numpy())
 
         def dummy_dump_features_hook(mod, inp, out):
             feature_vectors.append(None)
 
-        hook = dump_features_hook if dump_features else dummy_dump_features_hook
+        def dump_saliency_hook(model: torch.nn.Module, input: Tuple, out: List[torch.Tensor]):
+            """ Dump the last feature map to `saliency_maps` cache
+
+            Args:
+                model (torch.nn.Module): PyTorch model
+                input (Tuple): input
+                out (List[torch.Tensor]): a list of feature maps
+            """
+            with torch.no_grad():
+                saliency_map = get_saliency_map(out[-1])
+                saliency_maps.append(saliency_map.squeeze(0).detach().cpu().numpy())
+
+        def dummy_dump_saliency_hook(model, input, out):
+            saliency_maps.append(None)
+
+        feature_vector_hook = dump_features_hook if dump_features else dummy_dump_features_hook
+        saliency_map_hook = dump_saliency_hook if dump_saliency_map else dummy_dump_saliency_hook
+
         # Use a single gpu for testing. Set in both mm_val_dataloader and eval_model
         if is_module_wrapper(model):
             model = model.module
-        with model.backbone.register_forward_hook(hook):
-            for data in data_loader:
-                with torch.no_grad():
-                    result = eval_model(return_loss=False, rescale=True, **data)
-                eval_predictions.extend(result)
+        with eval_model.module.backbone.register_forward_hook(feature_vector_hook):
+            with eval_model.module.backbone.register_forward_hook(saliency_map_hook):
+                for data in data_loader:
+                    with torch.no_grad():
+                        result = eval_model(return_loss=False, rescale=True, **data)
+                    eval_predictions.extend(result)
 
         for key in [
             'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
@@ -175,5 +198,7 @@ def dummy_dump_features_hook(mod, inp, out):
             classes=target_classes,
             detections=eval_predictions,
             metric=metric,
+            feature_vectors=feature_vectors,
+            saliency_maps=saliency_maps
         )
         return outputs
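
Note on the hook mechanism the patch relies on (an illustrative sketch, not part of the diff): register_forward_hook returns a torch.utils.hooks.RemovableHandle, which also works as a context manager, so each hook is detached as soon as its with-block exits. The toy backbone and the adaptive-average pooling below stand in for model.backbone and get_feature_vector; they are assumptions made only to keep the example self-contained and runnable.

import torch

feature_vectors = []

def dump_features_hook(module, inputs, output):
    # Pool the (N, C, H, W) feature map down to one C-dim vector per image,
    # mirroring the pooled feature vector collected by the patch above.
    with torch.no_grad():
        vec = torch.nn.functional.adaptive_avg_pool2d(output, (1, 1))
        feature_vectors.append(vec.view(vec.size(0), -1).cpu().numpy())

# Toy "backbone" standing in for model.backbone; illustrative only.
backbone = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)

# The RemovableHandle acts as a context manager, so the hook is removed on exit.
with backbone.register_forward_hook(dump_features_hook):
    with torch.no_grad():
        backbone(torch.randn(1, 3, 32, 32))

print(feature_vectors[0].shape)  # (1, 8)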