Integrate new ignored loss in semantic segmentation (#2055)
* HOT-FIX: Revert segmentation model's ignore mode in CLI (#2011)

Revert segmentation ignore=True

* Improve tiling preprocess (#2013)

* prevent timeout during init phase

* Fix reg tests (#2008)

* Edit regression tests

* Change the dataset root

* Fix missed typo

* Fix pre-commit

* Fix openvino import error due to Tiler init import (#2015)

Remove init import for Tiler to prevent OpenVINO import

* Bump up version to 1.2.0 (#2017)

* Set the python version to "3.10" for code-scan workflow

* Add missing __init__.py (#2019)

* Add missing __init__.py

* Change license

* Release 1.2.0rc1

* Fix issue where str2bool was not applied in certain cases (#2023)

* Add workaround solution

* Fix minor

* Remove str int

* Fix default dict (#2025)

fix: change default to configdict

Signed-off-by: Inhyuk Andy Cho <andy.inhyuk.jo@intel.com>

* Convert dummy datasets to toy datasets (#1988)

* Update cls, det datasets

* Remove useless files

* Change action datasets

* Edit action dataset

* change dir

* Add xml files

* Remove useless

* Edit tests

* Fix tests

* Fix tests

* Remove ptc

* Remove

* Fix precommit

* Update dataset, fix cls bug

* Remove useless dataset

* Edit drop_last

* Fix missed part

* Unify threshold values

* bugfix: squeeze to 1 dimension

* Change threshold for deployment

* Fix multi-gpu issue, e2e tests

* Decrease num_workers for tiling test and tiling processes

* Revert num_workers for tests

* Fix datasets

---------

Co-authored-by: eunwoosh <eunwoo.shin@intel.com>

* Fix E2E tests (#2032)

* Optimize data preprocessing time and enhance overall performance in semantic segmentation (#2020)

* add unit test

* exp

* intg agg

* implement ignore label

* no update in model.py

* ignore in label

* only updated ignored loss

* final

* final

* exp

* refactor

* refactor ignore loss

* revert

* default mode is ignore false

* unit test revised

* model update

* revise detcon

* fix error in intg test

* test case revised

* make run in semisl

* add type hints

* update docs

* test case added

---------

Signed-off-by: Inhyuk Andy Cho <andy.inhyuk.jo@intel.com>
Co-authored-by: Harim Kang <harim.kang@intel.com>
Co-authored-by: Eugene Liu <eugene.liu@intel.com>
Co-authored-by: Sungman Cho <sungman.cho@intel.com>
Co-authored-by: Songki Choi <songki.choi@intel.com>
Co-authored-by: Yunchu Lee <yunchu.lee@intel.com>
Co-authored-by: Jaeguk Hyun <jaeguk.hyun@intel.com>
Co-authored-by: Inhyuk Cho <andy.inhyuk.jo@intel.com>
Co-authored-by: eunwoosh <eunwoo.shin@intel.com>
Co-authored-by: Lee, Soobee <soobeele@intel.com>
10 people authored Apr 25, 2023
1 parent 5742c5f commit 919ac9e
Showing 23 changed files with 371 additions and 117 deletions.
9 changes: 1 addition & 8 deletions otx/algorithms/segmentation/adapters/mmseg/configurer.py
@@ -201,10 +201,7 @@ def configure_ignore(self, cfg: Config) -> None:
             )
 
         if "decode_head" in cfg.model:
-            decode_head = cfg.model.decode_head
-            if decode_head.type == "FCNHead":
-                decode_head.type = "CustomFCNHead"
-                decode_head.loss_decode = cfg_loss_decode
+            cfg.model.decode_head.loss_decode = cfg_loss_decode
 
     # pylint: disable=too-many-branches
     def configure_classes(self, cfg: Config) -> None:
@@ -537,10 +534,6 @@ def configure_task(self, cfg: ConfigDict, training: bool, **kwargs: Any) -> None
         """Adjust settings for task adaptation."""
         super().configure_task(cfg, training, **kwargs)
 
-        # Don't pass task_adapt arg to semi-segmentor
-        if cfg.model.type != "ClassIncrEncoderDecoder" and cfg.model.get("task_adapt", False):
-            cfg.model.pop("task_adapt")
-
         # Remove task adapt hook (set default torch random sampler)
         remove_custom_hook(cfg, "TaskAdaptHook")
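With this change, `configure_ignore` only swaps the decode head's loss for the ignore-aware one and no longer rewrites the head type on the fly; the model recipes are expected to reference `CustomFCNHead` directly. A minimal sketch of the resulting pattern, assuming `cfg_loss_decode` is a `ConfigDict` selecting the new loss (its exact fields sit above this hunk and are assumed here):

    # Sketch only: the fields of cfg_loss_decode are assumed, not shown in this hunk.
    from mmcv.utils import ConfigDict

    cfg_loss_decode = ConfigDict(
        type="CrossEntropyLossWithIgnore",  # registered by the loss module further below
        use_sigmoid=False,
        loss_weight=1.0,
    )

    def configure_ignore(cfg) -> None:
        """Replace the decode head's loss with the ignore-aware cross-entropy."""
        if "decode_head" in cfg.model:
            cfg.model.decode_head.loss_decode = cfg_loss_decode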
@@ -4,30 +4,158 @@
 # SPDX-License-Identifier: Apache-2.0
 #
 
+from typing import Dict, List, Optional
+
+import torch
+from mmcv.runner import force_fp32
 from mmseg.models.builder import HEADS
 from mmseg.models.decode_heads.fcn_head import FCNHead
+from mmseg.models.losses import accuracy
+from mmseg.ops import resize
+from torch import nn
 
-from .mixin import (
-    AggregatorMixin,
-    MixLossMixin,
-    PixelWeightsMixin2,
-    SegmentOutNormMixin,
-)
+from otx.algorithms.segmentation.adapters.mmseg.models.utils import IterativeAggregator
+from otx.algorithms.segmentation.adapters.mmseg.utils import (
+    get_valid_label_mask_per_batch,
+)
 
 
 @HEADS.register_module()
-class CustomFCNHead(
-    SegmentOutNormMixin, AggregatorMixin, MixLossMixin, PixelWeightsMixin2, FCNHead
-):  # pylint: disable=too-many-ancestors
-    """Custom Fully Convolution Networks for Semantic Segmentation."""
+class CustomFCNHead(FCNHead):
+    """Custom Fully Convolutional Network head for Semantic Segmentation.
+
+    This FCN head adds the aggregator used in Lite-HRNet, built with
+    DepthwiseSeparableConvModule. Please refer to
+    https://github.com/HRNet/Lite-HRNet.
+
+    Args:
+        enable_aggregator (bool): If True, use the Lite-HRNet aggregator,
+            concatenating all inputs from the backbone via
+            DepthwiseSeparableConvModule.
+        aggregator_min_channels (int, optional): Number of output channels of
+            the aggregator. Takes effect only if enable_aggregator is True.
+        aggregator_merge_norm (str, optional): How to normalize the outputs of
+            the aggregator's expanders. Options: "none", "channel", or None.
+        aggregator_use_concat (bool, optional): Whether to concatenate the last
+            input with the output of the expanders.
+    """
 
-    def __init__(self, *args, **kwargs):
+    def __init__(
+        self,
+        enable_aggregator: bool = False,
+        aggregator_min_channels: Optional[int] = None,
+        aggregator_merge_norm: Optional[str] = None,
+        aggregator_use_concat: bool = False,
+        *args,
+        **kwargs
+    ):
+        in_channels = kwargs.get("in_channels")
+        in_index = kwargs.get("in_index")
+        norm_cfg = kwargs.get("norm_cfg")
+        conv_cfg = kwargs.get("conv_cfg")
+        input_transform = kwargs.get("input_transform")
+
+        aggregator = None
+        if enable_aggregator:  # Lite-HRNet aggregator
+            assert isinstance(in_channels, (tuple, list))
+            assert len(in_channels) > 1
+
+            aggregator = IterativeAggregator(
+                in_channels=in_channels,
+                min_channels=aggregator_min_channels,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                merge_norm=aggregator_merge_norm,
+                use_concat=aggregator_use_concat,
+            )
+
+            aggregator_min_channels = aggregator_min_channels if aggregator_min_channels is not None else 0
+            # change arguments temporarily
+            kwargs["in_channels"] = max(in_channels[0], aggregator_min_channels)
+            kwargs["input_transform"] = None
+            if in_index is not None:
+                kwargs["in_index"] = in_index[0]
+
         super().__init__(*args, **kwargs)
 
+        self.aggregator = aggregator
+
+        # re-define variables
+        self.in_channels = in_channels
+        self.input_transform = input_transform
+        self.in_index = in_index
+
+        self.ignore_index = 255
+
         # get rid of last activation of convs module
         if self.act_cfg:
             self.convs[-1].with_activation = False
             delattr(self.convs[-1], "activate")
 
         if kwargs.get("init_cfg", {}):
             self.init_weights()
 
+    def _transform_inputs(self, inputs: torch.Tensor):
+        if self.aggregator is not None:
+            inputs = self.aggregator(inputs)[0]
+        else:
+            inputs = super()._transform_inputs(inputs)
+        return inputs
+
+    def forward_train(
+        self, inputs: torch.Tensor, img_metas: List[Dict], gt_semantic_seg: torch.Tensor, train_cfg: Dict
+    ):
+        """Forward function for training.
+
+        Args:
+            inputs (list[Tensor]): List of multi-level img features.
+            img_metas (list[dict]): List of image info dicts where each dict
+                has 'img_shape', 'scale_factor', 'flip', and may also contain
+                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                For details on the values of these keys, see
+                `mmseg/datasets/pipelines/formatting.py:Collect`.
+            gt_semantic_seg (Tensor): Semantic segmentation masks,
+                used if the architecture supports the semantic segmentation task.
+            train_cfg (dict): The training config.
+
+        Returns:
+            dict[str, Tensor]: a dictionary of loss components
+        """
+        seg_logits = self(inputs)
+        valid_label_mask = get_valid_label_mask_per_batch(img_metas, self.num_classes)
+        losses = self.losses(seg_logits, gt_semantic_seg, valid_label_mask=valid_label_mask)
+        return losses
+
+    @force_fp32(apply_to=("seg_logit",))
+    def losses(self, seg_logit: torch.Tensor, seg_label: torch.Tensor, valid_label_mask: Optional[torch.Tensor] = None):
+        """Compute segmentation loss."""
+        loss = dict()
+
+        seg_logit = resize(input=seg_logit, size=seg_label.shape[2:], mode="bilinear", align_corners=self.align_corners)
+        if self.sampler is not None:
+            seg_weight = self.sampler.sample(seg_logit, seg_label)
+        else:
+            seg_weight = None
+        seg_label = seg_label.squeeze(1)
+
+        if not isinstance(self.loss_decode, nn.ModuleList):
+            losses_decode = [self.loss_decode]
+        else:
+            losses_decode = self.loss_decode
+        for loss_decode in losses_decode:
+            valid_label_mask_cfg = dict()
+            if loss_decode.loss_name == "loss_ce_ignore":
+                valid_label_mask_cfg["valid_label_mask"] = valid_label_mask
+            if loss_decode.loss_name not in loss:
+                loss[loss_decode.loss_name] = loss_decode(
+                    seg_logit, seg_label, weight=seg_weight, ignore_index=self.ignore_index, **valid_label_mask_cfg
+                )
+            else:
+                loss[loss_decode.loss_name] += loss_decode(
+                    seg_logit, seg_label, weight=seg_weight, ignore_index=self.ignore_index, **valid_label_mask_cfg
+                )
+
+        loss["acc_seg"] = accuracy(seg_logit, seg_label, ignore_index=self.ignore_index)
+
+        return loss
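The rewritten head folds the former aggregator mixins into `CustomFCNHead` itself: with `enable_aggregator=True`, multi-scale backbone features are merged by `IterativeAggregator` before the plain FCN logic runs, and `losses()` routes a per-image valid-label mask only to losses registered under the name `loss_ce_ignore`. A hypothetical decode-head config wiring this up; the channel counts and class count below are illustrative, not taken from an actual OTX recipe:

    # Illustrative values only; in_channels must list several scales when
    # enable_aggregator is True (the head then reduces them to in_channels[0]).
    decode_head = dict(
        type="CustomFCNHead",
        in_channels=[40, 80, 160, 320],
        in_index=[0, 1, 2, 3],
        channels=40,
        num_convs=1,
        num_classes=21,
        norm_cfg=dict(type="BN", requires_grad=True),
        enable_aggregator=True,
        aggregator_min_channels=40,
        aggregator_merge_norm="none",
        aggregator_use_concat=False,
        loss_decode=dict(type="CrossEntropyLossWithIgnore", loss_weight=1.0),
    )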
@@ -3,61 +3,97 @@
 # SPDX-License-Identifier: Apache-2.0
 #
 
+from typing import Optional
+
 import torch
 import torch.nn.functional as F
 from mmseg.models.builder import LOSSES
-from mmseg.models.losses.utils import get_class_weight
-
-from .otx_pixel_base import OTXBasePixelLoss
+from mmseg.models.losses import CrossEntropyLoss
+from mmseg.models.losses.utils import weight_reduce_loss
 
 
 @LOSSES.register_module()
-class CrossEntropyLossWithIgnore(OTXBasePixelLoss):
+class CrossEntropyLossWithIgnore(CrossEntropyLoss):
     """CrossEntropyLossWithIgnore with Ignore Mode Support for Class Incremental Learning.
 
-    Args:
-        model_classes (list[str]): Model classes
-        bg_aware (bool, optional): Whether to enable BG-aware loss
-            'background' class would be added to the start of model classes/label schema
-        reduction (str, optional): Reduction method. Defaults to 'mean'.
-            Options are "none", "mean" and "sum".
-        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
+    When new classes are added through continual training cycles, images from previous cycles
+    may become partially annotated if they are not revisited.
+    To prevent the model from predicting these new classes for such images,
+    CrossEntropyLossWithIgnore can be used to ignore the unseen classes.
     """
 
-    def __init__(self, reduction="mean", loss_weight=None, **kwargs):
-        super().__init__(**kwargs)
-        self.reduction = reduction
-        self.class_weight = get_class_weight(loss_weight)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._loss_name = "loss_ce_ignore"
 
+    def forward(
+        self,
+        cls_score: Optional[torch.Tensor],
+        label: Optional[torch.Tensor],
+        weight: Optional[torch.Tensor] = None,
+        avg_factor: Optional[int] = None,
+        reduction_override: Optional[str] = "mean",
+        ignore_index: int = 255,
+        valid_label_mask: Optional[torch.Tensor] = None,
+        **kwargs
+    ):
+        """Forward.
+
+        Args:
+            cls_score (torch.Tensor, optional): The prediction with shape (N, 1).
+            label (torch.Tensor, optional): The learning label of the prediction.
+            weight (torch.Tensor, optional): Sample-wise loss weight.
+                Default: None.
+            avg_factor (int, optional): Average factor that is used to average
+                the loss. Default: None.
+            reduction_override (str, optional): The method used to reduce the loss.
+                Options are 'none', 'mean' and 'sum'. Default: 'mean'.
+            ignore_index (int): Specifies a target value that is ignored and
+                does not contribute to the input gradients. When
+                ``avg_non_ignore`` is ``True`` and the ``reduction`` is
+                ``'mean'``, the loss is averaged over non-ignored targets.
+                Default: 255.
+            valid_label_mask (torch.Tensor, optional): The valid labels with
+                shape (N, num_classes). If the value in valid_label_mask is 0,
+                the mask label of the class corresponding to its index will be
+                ignored like ignore_index.
+            **kwargs (Any): Additional keyword arguments.
+        """
+        if valid_label_mask is None:
+            losses = super().forward(cls_score, label, weight, avg_factor, reduction_override, ignore_index, **kwargs)
+            return losses
+        else:
+            assert reduction_override in (None, "none", "mean", "sum")
+            reduction = reduction_override if reduction_override else self.reduction
+            batch_size = label.shape[0]
+            for i in range(batch_size):
+                invalid_labels = (valid_label_mask[i] == 0).nonzero(as_tuple=False)
+                for inv_l in invalid_labels:
+                    label[i] = torch.where(label[i] == inv_l.item(), ignore_index, label[i])
+
+            losses = F.cross_entropy(cls_score, label, reduction="none", ignore_index=ignore_index)
+
+            if weight is not None:
+                weight = weight.float()
+            losses = weight_reduce_loss(losses, weight=weight, reduction=reduction, avg_factor=avg_factor)
+
+            return losses
 
     @property
-    def name(self):
-        """name."""
-        return "ce_with_ignore"
-
-    def _calculate(self, cls_score, label, valid_label_mask, scale):
-        if cls_score.shape[0] == 0:
-            return torch.tensor(0.0)
-
-        batch_size = label.shape[0]
-        label = torch.from_numpy(label).to(cls_score.device)
-        probs_all = F.softmax(scale * cls_score, dim=1)
-        losses_l = []
-        for i in range(batch_size):
-            probs_gathered = probs_all[i, valid_label_mask[i] == 1]
-            probs_nomatch = probs_all[i, valid_label_mask[i] == 0]
-            probs_gathered = torch.unsqueeze(probs_gathered, 0)
-            probs_nomatch = torch.unsqueeze(probs_nomatch, 0)
-
-            probs_gathered[:, 0] += probs_nomatch.sum(dim=1)
-            each_prob_log = torch.log(probs_gathered)
-
-            # X-entropy: NLL loss w/ log-probabilities & labels
-            each_label = torch.unsqueeze(label[i], 0)
-            each_label = each_label.to(cls_score.device)
-            loss = F.nll_loss(each_prob_log, each_label, reduction="none", ignore_index=self.ignore_index)
-            losses_l.append(loss)
-
-        losses = torch.cat(losses_l, dim=0)
-
-        return losses, cls_score
+    def loss_name(self):
+        """Loss Name.
+
+        This function must be implemented and will return the name of this
+        loss function. This name will be used to combine different loss items
+        by simple sum operation. In addition, if you want this loss item to be
+        included into the backward graph, `loss_` must be the prefix of the
+        name.
+
+        Returns:
+            str: The name of this loss item.
+        """
+        return self._loss_name
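When `valid_label_mask` is given, the loss remaps every class an image was never annotated with onto `ignore_index` and then falls back to ordinary cross-entropy, so unseen classes contribute no gradient. A self-contained sketch of that masking step, with shapes and values invented for illustration:

    import torch
    import torch.nn.functional as F

    ignore_index = 255
    cls_score = torch.randn(2, 3, 4, 4)          # (N, C, H, W) logits for 3 classes
    label = torch.randint(0, 3, (2, 4, 4))       # (N, H, W) targets
    valid_label_mask = torch.tensor([[1, 1, 1],  # image 0: every class annotated
                                     [1, 1, 0]]) # image 1: class 2 unseen -> ignored

    for i in range(label.shape[0]):
        invalid_labels = (valid_label_mask[i] == 0).nonzero(as_tuple=False)
        for inv_l in invalid_labels:
            # pixels labeled with an unseen class stop contributing to the loss
            label[i] = torch.where(label[i] == inv_l.item(), ignore_index, label[i])

    loss = F.cross_entropy(cls_score, label, reduction="mean", ignore_index=ignore_index)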
@@ -26,7 +26,7 @@
 
 from otx.algorithms.common.utils.logger import get_logger
 
-from .class_incr_encoder_decoder import ClassIncrEncoderDecoder
+from .class_incr_encoder_decoder import OTXEncoderDecoder
 
 logger = get_logger()
 
@@ -442,7 +442,7 @@ def state_dict_hook(module, state_dict, *args, **kwargs):
 
 # pylint: disable=too-many-locals
 @SEGMENTORS.register_module()
-class SupConDetConB(ClassIncrEncoderDecoder):  # pylint: disable=too-many-ancestors
+class SupConDetConB(OTXEncoderDecoder):  # pylint: disable=too-many-ancestors
     """Apply DetConB as a contrastive part of `Supervised Contrastive Learning` (https://arxiv.org/abs/2004.11362).
 
     SupCon with DetConB uses ground truth masks instead of pseudo masks to organize features among the same classes.
@@ -497,7 +496,6 @@ def forward_train(
         img,
         img_metas,
         gt_semantic_seg,
-        pixel_weights=None,
         **kwargs,
     ):
         """Forward function for training.
@@ -507,7 +506,6 @@
             img_metas (list[dict]): Input information.
             gt_semantic_seg (Tensor): Ground truth masks.
                 It is used to organize features among the same classes.
-            pixel_weights (Tensor): Pixel weights.
             **kwargs (Any): Additional keyword arguments.
 
         Returns:
@@ -527,9 +525,7 @@
         img_metas += img_metas
 
         # decode head
-        loss_decode, _ = self._decode_head_forward_train(
-            embds, img_metas, gt_semantic_seg=masks, pixel_weights=pixel_weights
-        )
+        loss_decode = self._decode_head_forward_train(embds, img_metas, gt_semantic_seg=masks)
         losses.update(loss_decode)
return losses