Skip to content

Commit

Permalink
Fix Dice score. Add mIoU metric for semantic segmentation (#3264)
Browse files Browse the repository at this point in the history
* fix mdice. Add mIoU

* rename Dice

* fix pre-commit. tests. Make ignore_index configurable

* reply comments
  • Loading branch information
kprokofi authored Apr 8, 2024
1 parent ce4b7fc commit d31cb33
Show file tree
Hide file tree
Showing 10 changed files with 88 additions and 24 deletions.
4 changes: 2 additions & 2 deletions src/otx/algo/segmentation/dino_v2_seg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import TYPE_CHECKING, Any

from otx.algo.utils.mmconfig import read_mmconfig
from otx.core.metrics.dice import DiceCallable
from otx.core.metrics.dice import SegmCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.segmentation import MMSegCompatibleModel
from otx.core.schedulers import LRSchedulerListCallable
Expand All @@ -26,7 +26,7 @@ def __init__(
num_classes: int,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = DiceCallable,
metric: MetricCallable = SegmCallable, # type: ignore[assignment]
torch_compile: bool = False,
) -> None:
model_name = "dino_v2_seg"
Expand Down
4 changes: 2 additions & 2 deletions src/otx/algo/segmentation/litehrnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from otx.algo.utils.mmconfig import read_mmconfig
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.metrics.dice import DiceCallable
from otx.core.metrics.dice import SegmCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.segmentation import MMSegCompatibleModel
from otx.core.schedulers import LRSchedulerListCallable
Expand All @@ -31,7 +31,7 @@ def __init__(
variant: Literal["18", 18, "s", "x"],
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = DiceCallable,
metric: MetricCallable = SegmCallable, # type: ignore[assignment]
torch_compile: bool = False,
) -> None:
self.model_name = f"litehrnet_{variant}"
Expand Down
4 changes: 2 additions & 2 deletions src/otx/algo/segmentation/segnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from otx.algo.utils.mmconfig import read_mmconfig
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.metrics.dice import DiceCallable
from otx.core.metrics.dice import SegmCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.segmentation import MMSegCompatibleModel
from otx.core.schedulers import LRSchedulerListCallable
Expand All @@ -28,7 +28,7 @@ def __init__(
variant: Literal["b", "s", "t"],
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = DiceCallable,
metric: MetricCallable = SegmCallable, # type: ignore[assignment]
torch_compile: bool = False,
) -> None:
model_name = f"segnext_{variant}"
Expand Down
6 changes: 1 addition & 5 deletions src/otx/core/data/dataset/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,21 +158,17 @@ def __init__(
self.label_info = SegLabelInfo(
label_names=self.label_info.label_names,
label_groups=self.label_info.label_groups,
ignore_index=ignore_index,
)
self.ignore_index = ignore_index

def _get_item_impl(self, index: int) -> SegDataEntity | None:
item = self.dm_subset.get(id=self.ids[index], subset=self.dm_subset.name)
img = item.media_as(Image)
num_classes = self.label_info.num_classes
ignored_labels: list[int] = []
img_data, img_shape = self._get_img_data_and_shape(img)

mask = _extract_class_mask(item=item, img_shape=img_shape, ignore_index=self.ignore_index)

# assign possible ignored labels from dataset to max label class + 1.
# it is needed to compute mDice metric.
mask[mask == 255] = num_classes
entity = SegDataEntity(
image=img_data,
img_info=ImageInfo(
Expand Down
64 changes: 60 additions & 4 deletions src/otx/core/metrics/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,72 @@
# SPDX-License-Identifier: Apache-2.0
#
"""Module for OTX Dice metric used for the OTX semantic segmentation task."""
from __future__ import annotations

from typing import TYPE_CHECKING, Literal

from torchmetrics import JaccardIndex
from torchmetrics.classification.dice import Dice
from torchmetrics.collections import MetricCollection

from otx.core.types.label import LabelInfo
from otx.core.types.label import SegLabelInfo

if TYPE_CHECKING:
from torch import Tensor

def _dice_callable(label_info: LabelInfo) -> MetricCollection:

def _segm_callable(label_info: SegLabelInfo) -> MetricCollection:
return MetricCollection(
{"Dice": Dice(num_classes=label_info.num_classes + 1, ignore_index=label_info.num_classes)},
{
"Dice": OTXDice(num_classes=label_info.num_classes, ignore_index=label_info.ignore_index, average="macro"),
"mIoU": JaccardIndex(
task="multiclass",
num_classes=label_info.num_classes,
ignore_index=label_info.ignore_index,
),
},
)


DiceCallable = _dice_callable
class OTXDice(Dice):
"""Dice metric used for the OTX semantic segmentation task."""

def __init__(
self,
zero_division: int = 0,
num_classes: int | None = None,
threshold: float = 0.5,
average: Literal["micro", "macro", "none"] = "micro",
mdmc_average: str = "global",
ignore_index: int | None = None,
top_k: int | None = None,
multiclass: bool | None = None,
**kwargs,
) -> None:
super().__init__(
zero_division=zero_division,
num_classes=num_classes,
threshold=threshold,
average=average,
mdmc_average=mdmc_average,
ignore_index=None,
top_k=top_k,
multiclass=multiclass,
**kwargs,
)
# workaround to use ignore index > num_classes or < 0
self.extended_ignore_index = ignore_index

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets. Fix ignore_index handling."""
if self.extended_ignore_index is not None:
filtered_preds = preds[target != self.extended_ignore_index]
filtered_target = target[target != self.extended_ignore_index]
else:
filtered_preds = preds
filtered_target = target

super().update(filtered_preds, filtered_target)


SegmCallable = _segm_callable
8 changes: 4 additions & 4 deletions src/otx/core/model/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from otx.core.exporter.base import OTXModelExporter
from otx.core.exporter.native import OTXNativeModelExporter
from otx.core.metrics import MetricInput
from otx.core.metrics.dice import DiceCallable
from otx.core.metrics.dice import SegmCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.label import SegLabelInfo
Expand All @@ -42,7 +42,7 @@ def __init__(
num_classes: int,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = DiceCallable,
metric: MetricCallable = SegmCallable, # type: ignore[assignment]
torch_compile: bool = False,
):
super().__init__(
Expand Down Expand Up @@ -100,7 +100,7 @@ def __init__(
config: DictConfig,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = DiceCallable,
metric: MetricCallable = SegmCallable, # type: ignore[assignment]
torch_compile: bool = False,
) -> None:
config = inplace_num_classes(cfg=config, num_classes=num_classes)
Expand Down Expand Up @@ -233,7 +233,7 @@ def __init__(
max_num_requests: int | None = None,
use_throughput_mode: bool = True,
model_api_configuration: dict[str, Any] | None = None,
metric: MetricCallable = DiceCallable,
metric: MetricCallable = SegmCallable, # type: ignore[assignment]
**kwargs,
) -> None:
super().__init__(
Expand Down
3 changes: 2 additions & 1 deletion src/otx/core/types/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def from_json(cls, serialized: str) -> HLabelInfo:
class SegLabelInfo(LabelInfo):
"""Meta information of Semantic Segmentation."""

def __init__(self, label_names: list[str], label_groups: list[list[str]]) -> None:
def __init__(self, label_names: list[str], label_groups: list[list[str]], ignore_index: int = 255) -> None:
if not any(word.lower() == "background" for word in label_names):
msg = (
"Currently, no background label exists for `label_names`. "
Expand All @@ -281,6 +281,7 @@ def __init__(self, label_names: list[str], label_groups: list[list[str]]) -> Non
warnings.warn(msg, stacklevel=2)
label_names.insert(0, "Background")
super().__init__(label_names, label_groups)
self.ignore_index = ignore_index


@dataclass
Expand Down
4 changes: 4 additions & 0 deletions src/otx/recipe/_base_/data/mmseg_base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ config:
data_format: common_semantic_segmentation_with_subset_dirs
include_polygons: true
unannotated_items_ratio: 0.0
ignore_index: 255
train_subset:
subset_name: train
batch_size: 8
Expand Down Expand Up @@ -35,6 +36,9 @@ config:
size:
- 512
- 512
pad_val:
img: 0
seg: 255
- type: PackSegInputs
sampler:
class_path: torch.utils.data.RandomSampler
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def fxt_seg_data_entity() -> tuple[tuple, SegDataEntity, SegBatchDataEntity]:
img_size = (32, 32)
fake_image = torch.zeros(size=(3, *img_size), dtype=torch.uint8).numpy()
fake_image_info = ImageInfo(img_idx=0, img_shape=img_size, ori_shape=img_size)
fake_masks = Mask(torch.randint(low=0, high=255, size=img_size, dtype=torch.uint8))
fake_masks = Mask(torch.randint(low=0, high=2, size=img_size, dtype=torch.uint8))
# define data entity
single_data_entity = SegDataEntity(
image=fake_image,
Expand Down
13 changes: 10 additions & 3 deletions tests/unit/core/model/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from importlib_resources import files
from omegaconf import OmegaConf
from otx.core.model.segmentation import MMSegCompatibleModel
from otx.core.types.label import SegLabelInfo

if TYPE_CHECKING:
from omegaconf.dictconfig import DictConfig
Expand All @@ -25,7 +26,9 @@ def config(self) -> DictConfig:

@pytest.fixture()
def model(self, config) -> MMSegCompatibleModel:
return MMSegCompatibleModel(num_classes=2, config=config)
model = MMSegCompatibleModel(num_classes=2, config=config)
model.label_info = SegLabelInfo(label_names=["Background", "label_1"], label_groups=[["Background", "label_1"]])
return model

def test_create_model(self, model) -> None:
mmseg_model = model._create_model()
Expand Down Expand Up @@ -66,7 +69,9 @@ def test_validation_step(self, mocker, model, fxt_seg_data_entity) -> None:
mocker_update_loss = mocker.patch.object(
model,
"_convert_pred_entity_to_compute_metric",
return_value=[{"preds": torch.randn(size=[3, 3, 3]), "target": torch.randint(0, 2, size=[3, 3])}],
return_value=[
{"preds": torch.randint(0, 2, size=[1, 3, 3]), "target": torch.randint(0, 2, size=[1, 3, 3])},
],
)
model.validation_step(fxt_seg_data_entity[2], 0)
mocker_update_loss.assert_called_once()
Expand All @@ -77,7 +82,9 @@ def test_test_metric(self, mocker, model, fxt_seg_data_entity) -> None:
mocker_update_loss = mocker.patch.object(
model,
"_convert_pred_entity_to_compute_metric",
return_value=[{"preds": torch.randn(size=[3, 3, 3]), "target": torch.randint(0, 2, size=[3, 3])}],
return_value=[
{"preds": torch.randint(0, 2, size=[1, 3, 3]), "target": torch.randint(0, 2, size=[1, 3, 3])},
],
)
model.test_step(fxt_seg_data_entity[2], 0)
mocker_update_loss.assert_called_once()
Expand Down

0 comments on commit d31cb33

Please sign in to comment.