Support activation map #2860

Merged · 16 commits · Feb 2, 2024
41 changes: 35 additions & 6 deletions src/otx/algo/hooks/recording_forward_hook.py
@@ -21,7 +21,7 @@ class BaseRecordingForwardHook:
normalize (bool): Whether to normalize the resulting saliency maps.
"""

def __init__(self, head_forward_fn: Callable, normalize: bool = True) -> None:
def __init__(self, head_forward_fn: Callable | None = None, normalize: bool = True) -> None:
self._head_forward_fn = head_forward_fn
self.handle: RemovableHandle | None = None
self._records: list[torch.Tensor] = []
@@ -70,10 +70,11 @@ def recording_forward(

def _predict_from_feature_map(self, x: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
logits = self._head_forward_fn(x)
if not isinstance(logits, torch.Tensor):
logits = torch.tensor(logits)
return logits
if self._head_forward_fn:
x = self._head_forward_fn(x)
if not isinstance(x, torch.Tensor):
x = torch.tensor(x)
return x

def _torch_to_numpy_from_list(self, tensor_list: list[torch.Tensor | None]) -> None:
for i in range(len(tensor_list)):
@@ -97,6 +98,34 @@ def _normalize_map(saliency_maps: torch.Tensor) -> torch.Tensor:
return saliency_maps.to(torch.uint8)


class ActivationMapHook(BaseRecordingForwardHook):
"""ActivationMapHook. Mean of the feature map along the channel dimension."""

@classmethod
def create_and_register_hook(
cls,
backbone: torch.nn.Module,
) -> BaseRecordingForwardHook:
"""Create this object and register it to the module forward hook."""
hook = cls()
hook.handle = backbone.register_forward_hook(hook.recording_forward)
return hook

def func(self, feature_map: torch.Tensor | Sequence[torch.Tensor], fpn_idx: int = -1) -> torch.Tensor:
"""Generate the saliency map by average feature maps then normalizing to (0, 255)."""
if isinstance(feature_map, (list, tuple)):
feature_map = feature_map[fpn_idx]

batch_size, _, h, w = feature_map.size()
activation_map = torch.mean(feature_map, dim=1)

if self._norm_saliency_maps:
activation_map = activation_map.reshape((batch_size, h * w))
activation_map = self._normalize_map(activation_map)

return activation_map.reshape((batch_size, h, w))
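
For context, here is a minimal sketch of how the new hook might be exercised on its own; the `Conv2d` stand-in backbone and the input sizes are placeholders, not part of OTX.

```python
import torch

from otx.algo.hooks.recording_forward_hook import ActivationMapHook

# Toy stand-in for a backbone: any module emitting a [batch, C, H, W] feature map
# (or a list/tuple of them) works the same way.
backbone = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)

# Attach the hook so every forward pass records a channel-averaged activation map.
hook = ActivationMapHook.create_and_register_hook(backbone)

_ = backbone(torch.rand(2, 3, 32, 32))

# One record per forward call, each map shaped [batch, H, W] and scaled to [0, 255].
assert len(hook.records) == 1

# Detach the hook and clear the stored maps when finished.
hook.handle.remove()
hook.reset()
```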


class ReciproCAMHook(BaseRecordingForwardHook):
"""Implementation of Recipro-CAM for class-wise saliency map.

@@ -330,7 +359,7 @@ def func(
Returns:
torch.Tensor: Class-wise Saliency Maps. One saliency map per each class - [batch, class_id, H, W]
"""
cls_scores = self._head_forward_fn(feature_map)
cls_scores = self._head_forward_fn(feature_map) if self._head_forward_fn else feature_map

middle_idx = len(cls_scores) // 2
# resize to the middle feature map
48 changes: 48 additions & 0 deletions src/otx/algo/utils/xai_utils.py
@@ -0,0 +1,48 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Utils used for XAI."""

from __future__ import annotations

from typing import TYPE_CHECKING

import cv2

from otx.core.config.explain import ExplainConfig

if TYPE_CHECKING:
from pathlib import Path


def get_processed_saliency_maps(
raw_saliency_maps: list,
explain_config: ExplainConfig | None,
predictions: list | None,
work_dir: Path | None,
) -> list:
"""Implement saliency map filtering and post-processing."""
if work_dir:
# Temporarily save the saliency map for image 0, class 0 (for tests)
cv2.imwrite(str(work_dir / "saliency_map.tiff"), raw_saliency_maps[0][0])

selected_saliency_maps = select_saliency_maps(raw_saliency_maps, explain_config, predictions)
return process_saliency_maps(selected_saliency_maps, explain_config)


def select_saliency_maps(
saliency_maps: list,
explain_config: ExplainConfig | None, # noqa: ARG001
predictions: list | None, # noqa: ARG001
) -> list:
"""Select saliency maps in accordance with TargetExplainGroup."""
# Implement <- TODO(negvet)
return saliency_maps


def process_saliency_maps(
saliency_maps: list,
explain_config: ExplainConfig | None, # noqa: ARG001
) -> list:
"""Postptocess saliency maps."""
# Implement <- TODO(negvet)
return saliency_maps
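
For illustration, a rough sketch of how these helpers might be called on hook output; the array shapes are assumptions, and with the current pass-through TODO stubs the maps come back unchanged.

```python
import numpy as np

from otx.algo.utils.xai_utils import get_processed_saliency_maps

# Toy stand-in for hook output: one [num_classes, H, W] uint8 array per image.
raw_saliency_maps = [np.zeros((3, 7, 7), dtype=np.uint8) for _ in range(2)]

processed = get_processed_saliency_maps(
    raw_saliency_maps,
    explain_config=None,  # an ExplainConfig will steer selection/post-processing once implemented
    predictions=None,     # per-image predictions, unused until filtering lands
    work_dir=None,        # pass a directory Path to also dump a debug saliency_map.tiff
)

# With the current stubs, the output mirrors the input.
assert len(processed) == len(raw_saliency_maps)
```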
16 changes: 10 additions & 6 deletions src/otx/engine/engine.py
@@ -447,7 +447,7 @@ def explain(
--checkpoint <CKPT_PATH, str>
```
"""
import cv2
from otx.algo.utils.xai_utils import get_processed_saliency_maps

ckpt_path = str(checkpoint) if checkpoint is not None else self.checkpoint
if explain_config is None:
@@ -466,16 +466,20 @@

self._build_trainer(**kwargs)

self.trainer.predict(
predictions = self.trainer.predict(
model=lit_module,
datamodule=datamodule,
ckpt_path=ckpt_path,
)

# Optimize for memory <- TODO(negvet)
saliency_maps = self.trainer.model.model.explain_hook.records
# Temporary saving saliency map for image 0, class 0 (for tests)
cv2.imwrite(str(Path(self.work_dir) / "saliency_map.tiff"), saliency_maps[0][0])
return saliency_maps
raw_saliency_maps = self.trainer.model.model.explain_hook.records
return get_processed_saliency_maps(
raw_saliency_maps,
explain_config,
predictions,
Path(self.work_dir),
)

@classmethod
def from_config(cls, config_path: PathLike, data_root: PathLike | None = None, **kwargs) -> Engine:
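As a complement to the CLI snippet in the explain() docstring, a hedged sketch of the same flow through the Python API; the recipe path, data root, and checkpoint below are placeholders, and passing work_dir through from_config() is an assumption.

```python
from otx.engine import Engine

# Placeholder paths; substitute a real recipe, dataset, and checkpoint.
engine = Engine.from_config(
    config_path="recipes/classification/multi_class_cls/efficientnet_v2_light.yaml",
    data_root="data/my_dataset",
    work_dir="outputs/xai",
)

# explain() now runs prediction, collects the maps recorded by the model's
# explain_hook, and returns them after the (for now pass-through) post-processing.
saliency_maps = engine.explain(checkpoint="outputs/best.ckpt")
print(len(saliency_maps), "images with saliency maps")
```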
22 changes: 22 additions & 0 deletions tests/integration/cli/test_cli.py
@@ -191,6 +191,9 @@ def test_otx_explain_e2e(
Returns:
None
"""
import cv2
import numpy as np

task = recipe.split("/")[-2]
model_name = recipe.split("/")[-1].split(".")[0]

@@ -207,12 +210,18 @@
"explain",
"--config",
recipe,
"--model.num_classes",
"1000",
"--data_root",
fxt_target_dataset_per_task[task],
"--engine.work_dir",
str(tmp_path_explain / "outputs"),
"--engine.device",
fxt_accelerator,
"--seed",
"0",
"--deterministic",
"True",
*fxt_cli_override_command_per_task[task],
]

@@ -221,6 +230,19 @@

assert (tmp_path_explain / "outputs").exists()
assert (tmp_path_explain / "outputs" / "saliency_map.tiff").exists()
sal_map = cv2.imread(str(tmp_path_explain / "outputs" / "saliency_map.tiff"))
assert sal_map.shape[0] > 0
assert sal_map.shape[1] > 0

reference_sal_vals = {
"multi_label_cls_efficientnet_v2_light": np.array([66, 97, 84, 33, 42, 79, 0], dtype=np.uint8),
"h_label_cls_efficientnet_v2_light": np.array([43, 84, 61, 5, 54, 31, 57], dtype=np.uint8),
}
test_case_name = task + "_" + model_name
if test_case_name in reference_sal_vals:
# Compare int-cast values so uint8 subtraction cannot wrap around.
actual_sal_vals = sal_map[:, 0, 0].astype(np.int16)
ref_sal_vals = reference_sal_vals[test_case_name].astype(np.int16)
assert np.max(np.abs(actual_sal_vals - ref_sal_vals)) <= 3


@pytest.mark.parametrize("recipe", RECIPE_OV_LIST)
20 changes: 20 additions & 0 deletions tests/unit/algo/hooks/test_xai_hooks.py
@@ -2,12 +2,32 @@
# SPDX-License-Identifier: Apache-2.0
import torch
from otx.algo.hooks.recording_forward_hook import (
ActivationMapHook,
DetClassProbabilityMapHook,
ReciproCAMHook,
ViTReciproCAMHook,
)


def test_activationmap() -> None:
hook = ActivationMapHook()

assert hook.handle is None
assert hook.records == []
assert hook._norm_saliency_maps

feature_map = torch.zeros((1, 10, 5, 5))

saliency_maps = hook.func(feature_map)
assert saliency_maps.size() == torch.Size([1, 5, 5])

hook.recording_forward(None, None, feature_map)
assert len(hook.records) == 1

hook.reset()
assert hook.records == []


def test_reciprocam() -> None:
def cls_head_forward_fn(_) -> None:
return torch.zeros((25, 2))