Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

XAI for tiling: detection, instance segmentation #3297

Merged
Merged
5 changes: 4 additions & 1 deletion src/otx/algo/hooks/recording_forward_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,10 @@ def average_and_normalize(
Returns:
np.array: Class-wise Saliency Maps. One saliency map per each class - [class_id, H, W]
"""
masks, scores, labels = (pred.masks, pred.scores, pred.labels)
if isinstance(pred, dict):
masks, scores, labels = pred["masks"], pred["scores"], pred["labels"]
else:
masks, scores, labels = (pred.masks, pred.scores, pred.labels)
_, height, width = masks.shape

saliency_map = torch.zeros((num_classes, height, width), dtype=torch.float32, device=labels.device)
Expand Down
18 changes: 14 additions & 4 deletions src/otx/core/model/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from otx.core.config.data import TileConfig
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.detection import DetBatchDataEntity, DetBatchPredEntity
from otx.core.data.entity.tile import TileBatchDetDataEntity
from otx.core.data.entity.tile import OTXTileBatchDataEntity, TileBatchDetDataEntity
from otx.core.exporter.base import OTXModelExporter
from otx.core.metrics import MetricInput
from otx.core.metrics.mean_ap import MeanAPCallable
Expand Down Expand Up @@ -74,26 +74,33 @@ def forward_tiles(self, inputs: TileBatchDetDataEntity) -> DetBatchPredEntity:
tile_attrs: list[list[dict[str, int | str]]] = []
merger = DetectionTileMerge(
inputs.imgs_info,
self.tile_config.tile_size[0],
self.num_classes,
self.tile_config.iou_threshold,
self.tile_config.max_num_instances,
)
for batch_tile_attrs, batch_tile_input in inputs.unbind():
output = self.forward(batch_tile_input)
output = self.forward_explain(batch_tile_input) if self.explain_mode else self.forward(batch_tile_input)
if isinstance(output, OTXBatchLossEntity):
msg = "Loss output is not supported for tile merging"
raise TypeError(msg)
tile_preds.append(output)
tile_attrs.append(batch_tile_attrs)
pred_entities = merger.merge(tile_preds, tile_attrs)

return DetBatchPredEntity(
pred_entity = DetBatchPredEntity(
batch_size=inputs.batch_size,
images=[pred_entity.image for pred_entity in pred_entities],
imgs_info=[pred_entity.img_info for pred_entity in pred_entities],
scores=[pred_entity.score for pred_entity in pred_entities],
bboxes=[pred_entity.bboxes for pred_entity in pred_entities],
labels=[pred_entity.labels for pred_entity in pred_entities],
)
if self.explain_mode:
pred_entity.saliency_map = [pred_entity.saliency_map for pred_entity in pred_entities]
pred_entity.feature_vector = [pred_entity.feature_vector for pred_entity in pred_entities]

return pred_entity

@property
def _export_parameters(self) -> dict[str, Any]:
Expand Down Expand Up @@ -187,11 +194,14 @@ class ExplainableOTXDetModel(OTXDetectionModel):

def forward_explain(
self,
inputs: DetBatchDataEntity,
inputs: DetBatchDataEntity | TileBatchDetDataEntity,
) -> DetBatchPredEntity:
"""Model forward function."""
from otx.algo.hooks.recording_forward_hook import get_feature_vector

if isinstance(inputs, OTXTileBatchDataEntity):
return self.forward_tiles(inputs)

self.model.feature_vector_fn = get_feature_vector
self.model.explain_fn = self.get_explain_fn()

Expand Down
19 changes: 14 additions & 5 deletions src/otx/core/model/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
from openvino.model_api.tilers import InstanceSegmentationTiler
from torchvision import tv_tensors

from otx.algo.hooks.recording_forward_hook import get_feature_vector
from otx.core.config.data import TileConfig
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.instance_segmentation import InstanceSegBatchDataEntity, InstanceSegBatchPredEntity
from otx.core.data.entity.tile import TileBatchInstSegDataEntity
from otx.core.data.entity.tile import OTXTileBatchDataEntity, TileBatchInstSegDataEntity
from otx.core.exporter.base import OTXModelExporter
from otx.core.metrics import MetricInput
from otx.core.metrics.mean_ap import MaskRLEMeanAPCallable
Expand Down Expand Up @@ -83,19 +84,21 @@ def forward_tiles(self, inputs: TileBatchInstSegDataEntity) -> InstanceSegBatchP
tile_attrs: list[list[dict[str, int | str]]] = []
merger = InstanceSegTileMerge(
inputs.imgs_info,
self.tile_config.tile_size[0],
self.num_classes,
self.tile_config.iou_threshold,
self.tile_config.max_num_instances,
)
for batch_tile_attrs, batch_tile_input in inputs.unbind():
output = self.forward(batch_tile_input)
output = self.forward_explain(batch_tile_input) if self.explain_mode else self.forward(batch_tile_input)
if isinstance(output, OTXBatchLossEntity):
msg = "Loss output is not supported for tile merging"
raise TypeError(msg)
tile_preds.append(output)
tile_attrs.append(batch_tile_attrs)
pred_entities = merger.merge(tile_preds, tile_attrs)

return InstanceSegBatchPredEntity(
pred_entity = InstanceSegBatchPredEntity(
batch_size=inputs.batch_size,
images=[pred_entity.image for pred_entity in pred_entities],
imgs_info=[pred_entity.img_info for pred_entity in pred_entities],
Expand All @@ -105,6 +108,11 @@ def forward_tiles(self, inputs: TileBatchInstSegDataEntity) -> InstanceSegBatchP
masks=[pred_entity.masks for pred_entity in pred_entities],
polygons=[pred_entity.polygons for pred_entity in pred_entities],
)
if self.explain_mode:
pred_entity.saliency_map = [pred_entity.saliency_map for pred_entity in pred_entities]
pred_entity.feature_vector = [pred_entity.feature_vector for pred_entity in pred_entities]

return pred_entity

@property
def _export_parameters(self) -> dict[str, Any]:
Expand Down Expand Up @@ -237,10 +245,11 @@ class ExplainableOTXInstanceSegModel(OTXInstanceSegModel):

def forward_explain(
self,
inputs: InstanceSegBatchDataEntity,
inputs: InstanceSegBatchDataEntity | TileBatchInstSegDataEntity,
) -> InstanceSegBatchPredEntity:
"""Model forward function."""
from otx.algo.hooks.recording_forward_hook import get_feature_vector
if isinstance(inputs, OTXTileBatchDataEntity):
return self.forward_tiles(inputs)

self.model.feature_vector_fn = get_feature_vector
self.model.explain_fn = self.get_explain_fn()
Expand Down
Loading
Loading