Commit
Update inst-seg model export to include feature vector and saliency map (#4053)

* Fix MaskRCNN/RTMDet-Inst/MaskRCNNTV Explain Mode
eugene123tw authored Oct 25, 2024
1 parent 3eff132 commit 0d87ca6
Showing 12 changed files with 163 additions and 58 deletions.
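Taken together, the changes make the instance-segmentation export() paths return a dictionary with two extra tensors, "feature_vector" and "saliency_map", when explain mode is on. A minimal sketch of that output contract; the key names come from the diff below, while the shapes are illustrative assumptions:

import torch

explain_outputs = {
    "bboxes": torch.zeros(1, 100, 5),                 # [batch, num_det, x1/y1/x2/y2/score] (assumed shape)
    "labels": torch.zeros(1, 100, dtype=torch.long),  # [batch, num_det] (assumed shape)
    "masks": torch.zeros(1, 100, 28, 28),             # [batch, num_det, H, W] (assumed shape)
    "feature_vector": torch.zeros(1, 256),            # pooled backbone features
    "saliency_map": torch.zeros(1),                   # dummy; real maps are computed in ModelAPI
}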
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -15,6 +15,8 @@ All notable changes to this project will be documented in this file.
(https://github.com/openvinotoolkit/training_extensions/pull/3993)
- Update to work torch compile in detection
(https://github.com/openvinotoolkit/training_extensions/pull/4003)
- Fix MaskRCNN/RTMDet-Inst/MaskRCNNTV Explain Mode
(https://github.com/openvinotoolkit/training_extensions/pull/4053)

## \[2.3.0\]

16 changes: 15 additions & 1 deletion src/otx/algo/detection/detectors/single_stage_detector.py
@@ -12,12 +12,14 @@
import re
from typing import TYPE_CHECKING

import torch

from otx.algo.instance_segmentation.heads.rtmdet_inst_head import RTMDetInstSepBNHead
from otx.algo.modules.base_module import BaseModule
from otx.algo.utils.mmengine_utils import InstanceData
from otx.core.data.entity.detection import DetBatchDataEntity

if TYPE_CHECKING:
import torch
from torch import Tensor, nn


@@ -217,6 +219,18 @@ def export(
backbone_feat = self.extract_feat(batch_inputs)
bbox_head_feat = self.bbox_head.forward(backbone_feat)
feature_vector = self.feature_vector_fn(backbone_feat)
if isinstance(self.bbox_head, RTMDetInstSepBNHead):
# create a dummy saliency map, as saliency map generation is implemented in ModelAPI
saliency_map = torch.zeros(1)
bboxes, labels, masks = self.bbox_head.export(backbone_feat, batch_img_metas, rescale=rescale) # type: ignore[misc]
return {
"bboxes": bboxes,
"labels": labels,
"masks": masks,
"feature_vector": feature_vector,
"saliency_map": saliency_map,
}

saliency_map = self.explain_fn(bbox_head_feat[0])
bboxes, labels = self.bbox_head.export(backbone_feat, batch_img_metas, rescale=rescale)
return {
10 changes: 7 additions & 3 deletions src/otx/algo/explain/explain_algo.py
@@ -13,6 +13,7 @@

if TYPE_CHECKING:
import numpy as np
from torch import Tensor

from otx.algo.utils.mmengine_utils import InstanceData

@@ -22,7 +23,10 @@

def feature_vector_fn(feature_map: FeatureMapType) -> torch.Tensor:
"""Generate the feature vector by average pooling feature maps."""
if isinstance(feature_map, (list, tuple)):
if isinstance(feature_map, (list, tuple, dict)):
if isinstance(feature_map, dict):
feature_map = list(feature_map.values())

# aggregate feature maps from Feature Pyramid Network
feature_vector = [
# Spatially pooling and flatten, B x C x H x W => B x C'
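For reference, the dict handling added above can be exercised in isolation. A standalone sketch, where pool_feature_vector is a hypothetical helper and concatenating the pooled pyramid levels is an assumption drawn from the surrounding comments:

import torch
import torch.nn.functional as F

def pool_feature_vector(feature_map):
    # torchvision FPNs return an OrderedDict of pyramid levels, hence the dict branch
    if isinstance(feature_map, dict):
        feature_map = list(feature_map.values())
    if not isinstance(feature_map, (list, tuple)):
        feature_map = [feature_map]
    # Spatially pool and flatten each level: B x C x H x W -> B x C
    pooled = [F.adaptive_avg_pool2d(f, (1, 1)).flatten(1) for f in feature_map]
    return torch.cat(pooled, dim=1)

fpn_levels = {"p2": torch.rand(2, 256, 64, 64), "p3": torch.rand(2, 256, 32, 32)}
print(pool_feature_vector(fpn_levels).shape)  # torch.Size([2, 512])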
@@ -324,13 +328,13 @@ def __init__(self, num_classes: int) -> None:

def func(
self,
predictions: list[InstanceData],
predictions: list[InstanceData] | list[dict[str, Tensor]],
_: int = -1,
) -> list[np.ndarray]:
"""Generate saliency maps from predicted masks by averaging and normalizing them per-class.
Args:
predictions (list[InstanceData]): Predictions of Instance Segmentation model.
predictions (list[InstanceData] | list[dict[str, Tensor]]): Predictions of Instance Segmentation model.
Returns:
torch.Tensor: Class-wise Saliency Maps. One saliency map per class - [batch, class_id, H, W]
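The docstring above captures the whole algorithm: average the predicted instance masks per class, then normalize each class map to [0, 1]. A hedged sketch of that idea; masks_to_saliency is a hypothetical name, not part of the codebase:

import numpy as np

def masks_to_saliency(masks: np.ndarray, labels: np.ndarray, num_classes: int) -> np.ndarray:
    # masks: [num_instances, H, W] mask probabilities; labels: [num_instances]
    h, w = masks.shape[-2:]
    saliency = np.zeros((num_classes, h, w), dtype=np.float32)
    for class_id in range(num_classes):
        class_masks = masks[labels == class_id]
        if len(class_masks) == 0:
            continue  # no detections for this class; leave the map at zero
        avg = class_masks.mean(axis=0)
        span = avg.max() - avg.min()
        saliency[class_id] = (avg - avg.min()) / span if span > 0 else avg
    return saliency  # [class_id, H, W]; batching is handled by the caller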
6 changes: 4 additions & 2 deletions src/otx/algo/instance_segmentation/heads/rtmdet_inst_head.py
@@ -1094,8 +1094,10 @@ def _nms_with_mask_static(
mask_thr_binary (float): Binarization threshold for masks.
Returns:
tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
and `labels` of shape [N, num_det].
tuple[Tensor, Tensor, Tensor]:
dets (Tensor): The detection results of shape [N, num_boxes, 5].
labels (Tensor): The labels of shape [N, num_boxes].
masks (Tensor): The masks of shape [N, num_boxes, H, W].
"""
dets, labels, inds = multiclass_nms(
bboxes,
1 change: 1 addition & 0 deletions src/otx/algo/instance_segmentation/maskdino.py
@@ -247,6 +247,7 @@ def export(
self,
batch_inputs: Tensor,
batch_img_metas: list[dict],
explain_mode: bool = False,
) -> tuple[Tensor, Tensor, Tensor]:
"""Export the model."""
b, _, h, w = batch_inputs.size()
2 changes: 1 addition & 1 deletion src/otx/algo/instance_segmentation/maskrcnn.py
@@ -327,7 +327,7 @@ def _exporter(self) -> OTXModelExporter:
"opset_version": 11,
"autograd_inlining": False,
},
output_names=["bboxes", "labels", "masks"],
output_names=["bboxes", "labels", "masks", "feature_vector", "saliency_map"] if self.explain_mode else None,
)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
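This output_names switch is what makes the new tensors addressable by name after ONNX export. A toy sketch of the same pattern; the Toy module and file name are hypothetical stand-ins, not the real OTXModelExporter path:

import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        feat = x.mean(dim=(2, 3))  # stand-in for a pooled "feature_vector"
        return x, feat

explain_mode = True
torch.onnx.export(
    Toy(),
    torch.rand(1, 3, 32, 32),
    "toy.onnx",
    # named outputs only when explain mode is on, mirroring the diff
    output_names=["masks", "feature_vector"] if explain_mode else None,
    opset_version=11,
)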
8 changes: 5 additions & 3 deletions src/otx/algo/instance_segmentation/maskrcnn_tv.py
@@ -136,7 +136,9 @@ def _customize_outputs(
labels: list[torch.LongTensor] = []
masks: list[tv_tensors.Mask] = []

for img_info, prediction in zip(inputs.imgs_info, outputs):
# XAI wraps predictions in a dictionary under the key "predictions"
predictions = outputs["predictions"] if isinstance(outputs, dict) else outputs
for img_info, prediction in zip(inputs.imgs_info, predictions):
scores.append(prediction["scores"])
bboxes.append(
tv_tensors.BoundingBoxes(
@@ -220,7 +222,7 @@ def _exporter(self) -> OTXModelExporter:
"opset_version": 11,
"autograd_inlining": False,
},
output_names=["bboxes", "labels", "masks"],
output_names=["bboxes", "labels", "masks", "feature_vector", "saliency_map"] if self.explain_mode else None,
)

def forward_for_tracing(self, inputs: Tensor) -> tuple[Tensor, ...]:
@@ -230,4 +232,4 @@ def forward_for_tracing(self, inputs: Tensor) -> tuple[Tensor, ...]:
"image_shape": shape,
}
meta_info_list = [meta_info] * len(inputs)
return self.model.export(inputs, meta_info_list)
return self.model.export(inputs, meta_info_list, explain_mode=self.explain_mode)
64 changes: 46 additions & 18 deletions src/otx/algo/instance_segmentation/segmentors/maskrcnn_tv.py
@@ -38,22 +38,23 @@ def forward(self, entity: InstanceSegBatchDataEntity) -> dict[str, Tensor] | lis

image_list = ImageList(entity.images, img_shapes)
targets = []
for bboxes, labels, masks, polygons in zip(
entity.bboxes,
entity.labels,
entity.masks,
entity.polygons,
):
# NOTE: shift labels by 1 as 0 is reserved for background
_labels = labels + 1 if len(labels) else labels
targets.append(
{
"boxes": bboxes,
"labels": _labels,
"masks": masks,
"polygons": polygons,
},
)
if self.training:
for bboxes, labels, masks, polygons in zip(
entity.bboxes,
entity.labels,
entity.masks,
entity.polygons,
):
# NOTE: shift labels by 1 as 0 is reserved for background
_labels = labels + 1 if len(labels) else labels
targets.append(
{
"boxes": bboxes,
"labels": _labels,
"masks": masks,
"polygons": polygons,
},
)

features = self.backbone(image_list.tensors)
if isinstance(features, Tensor):
@@ -112,14 +113,41 @@ def export(
self,
batch_inputs: Tensor,
batch_img_metas: list[dict],
) -> tuple[list[Tensor], list[Tensor], list[Tensor]]:
"""Export the model with the given inputs and image metas."""
explain_mode: bool = False,
) -> tuple[list[Tensor], list[Tensor], list[Tensor]] | dict[str, Tensor | list[Tensor]]:
"""Export the model.
Args:
batch_inputs (Tensor): image input tensor.
batch_img_metas (list[dict]): image meta information.
explain_mode (bool, optional): whether to export the feature vector and saliency map. Defaults to False.
Returns:
tuple[list[Tensor], list[Tensor], list[Tensor]] | dict[str, Tensor | list[Tensor]]:
boxes (list[Tensor]): bounding boxes.
labels (list[Tensor]): labels.
masks_probs (list[Tensor]): mask probabilities.
feature_vector (Tensor, optional): feature vector.
saliency_map (Tensor, optional): saliency map.
"""
img_shapes = [img_meta["image_shape"] for img_meta in batch_img_metas]
image_list = ImageList(batch_inputs, img_shapes)
features = self.backbone(batch_inputs)
proposals, _ = self.rpn(image_list, features)
boxes, labels, masks_probs = self.roi_heads.export(features, proposals, image_list.image_sizes)
labels = [label - 1 for label in labels] # Convert back to 0-indexed labels

if explain_mode:
saliency_map = torch.zeros(1)
feature_vector = self.feature_vector_fn(features)
return {
"boxes": boxes,
"labels": labels,
"masks": masks_probs,
"feature_vector": feature_vector,
"saliency_map": saliency_map,
}

return boxes, labels, masks_probs


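Because export() now returns either a tuple or a dictionary, callers have to branch on the result type. A usage sketch, assuming model is a constructed MaskRCNN-TV segmentor and the inputs are prepared as in the diff:

outputs = model.export(batch_inputs, batch_img_metas, explain_mode=True)
if isinstance(outputs, dict):
    boxes, labels, masks = outputs["boxes"], outputs["labels"], outputs["masks"]
    feature_vector = outputs["feature_vector"]  # pooled backbone features
    saliency_map = outputs["saliency_map"]      # zeros(1) placeholder
else:
    boxes, labels, masks = outputs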
67 changes: 50 additions & 17 deletions src/otx/algo/instance_segmentation/segmentors/two_stage.py
@@ -20,6 +20,14 @@ class TwoStageDetector(nn.Module):
Two-stage detectors typically consist of a region proposal network and a
task-specific regression head.
Args:
backbone (nn.Module): Module that extracts features from the input image.
neck (nn.Module): Module that further processes the features extracted by the backbone (e.g. FPN).
rpn_head (nn.Module): Region proposal network head.
roi_head (nn.Module): Region of interest head.
roi_criterion (nn.Module): Criterion to calculate ROI loss.
rpn_criterion (nn.Module): Criterion to calculate RPN loss.
"""

def __init__(
@@ -88,9 +96,9 @@ def with_bbox(self) -> bool:

def forward(
self,
entity: torch.Tensor,
entity: Tensor,
mode: str = "tensor",
) -> dict[str, torch.Tensor] | list[InstanceData] | tuple[torch.Tensor] | torch.Tensor:
) -> dict[str, Tensor] | list[InstanceData] | tuple[Tensor, ...] | Tensor:
"""The unified entry for a forward process in both training and test.
The method should accept three modes: "tensor", "predict" and "loss":
@@ -106,18 +114,14 @@ def forward(
parameter update, which are supposed to be done in :meth:`train_step`.
Args:
inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
data_samples (list[:obj:`DetDataSample`], optional): A batch of
data samples that contain annotations and predictions.
Defaults to None.
entity (Tensor): The input tensor with shape (N, C, ...) in general.
mode (str): Return what kind of value. Defaults to 'tensor'.
Returns:
The return type depends on ``mode``.
- If ``mode="tensor"``, return a tensor or a tuple of tensor.
- If ``mode="predict"``, return a list of :obj:`DetDataSample`.
- If ``mode="predict"``, return a list of :obj:`InstanceData`.
- If ``mode="loss"``, return a dict of tensor.
"""
if mode == "loss":
@@ -142,12 +146,11 @@ def extract_feat(self, batch_inputs: Tensor) -> tuple[Tensor]:
x = self.neck(x)
return x

def loss(self, batch_inputs: InstanceSegBatchDataEntity) -> dict:
def loss(self, batch_inputs: InstanceSegBatchDataEntity) -> dict[str, Tensor]:
"""Calculate losses from a batch of inputs and data samples.
Args:
batch_inputs (Tensor): Input images of shape (N, C, H, W).
These should usually be mean centered and std scaled.
batch_inputs (InstanceSegBatchDataEntity): The input data entity.
Returns:
dict: A dictionary of loss components
@@ -224,21 +227,51 @@ def predict(

def export(
self,
batch_inputs: torch.Tensor,
batch_inputs: Tensor,
batch_img_metas: list[dict],
) -> tuple[torch.Tensor, ...]:
"""Export for two stage detectors."""
x = self.extract_feat(batch_inputs)
explain_mode: bool = False,
) -> tuple[Tensor, Tensor, Tensor] | dict[str, Tensor]:
"""Export the model for ONNX/OpenVINO.
Args:
batch_inputs (Tensor): image tensor with shape (N, C, H, W).
batch_img_metas (list[dict]): image information.
explain_mode (bool, optional): whether to return the feature vector and saliency map. Defaults to False.
Returns:
dict[str, Tensor]: Returned when explain mode is ON, containing the following items:
- bboxes (Tensor): bounding boxes.
- labels (Tensor): labels.
- masks (Tensor): masks.
- feature_vector (Tensor): feature vector.
- saliency_map (Tensor): dummy saliency map.
tuple[Tensor, Tensor, Tensor]: Returned when explain mode is OFF, containing the following items:
- bboxes (Tensor): bounding boxes.
- labels (Tensor): labels.
- masks (Tensor): masks.
"""
x = self.extract_feat(batch_inputs)
rpn_results_list = self.rpn_head.export(
x,
batch_img_metas,
rescale=False,
)

return self.roi_head.export(
bboxes, labels, masks = self.roi_head.export(
x,
rpn_results_list,
batch_img_metas,
rescale=False,
)

if explain_mode:
feature_vector = self.feature_vector_fn(x)
return {
"bboxes": bboxes,
"labels": labels,
"masks": masks,
"feature_vector": feature_vector,
# create a dummy tensor, as ModelAPI supports the saliency_map output
"saliency_map": torch.zeros(1),
}
return bboxes, labels, masks
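With named outputs in place, an exported model can be queried by output name at inference time. A hedged sketch using the OpenVINO runtime; the model path and input shape are assumptions, not taken from this commit:

import numpy as np
import openvino as ov

core = ov.Core()
compiled = core.compile_model("exported_model.xml", "CPU")  # hypothetical path
result = compiled(np.zeros((1, 3, 640, 640), dtype=np.float32))  # assumed input shape
feature_vector = result["feature_vector"]
saliency_map = result["saliency_map"]  # zeros(1) here; real saliency maps come from ModelAPI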