diff --git a/otx/algorithms/segmentation/adapters/mmseg/configurer.py b/otx/algorithms/segmentation/adapters/mmseg/configurer.py
index af41c11ea51..0fc24cd5f49 100644
--- a/otx/algorithms/segmentation/adapters/mmseg/configurer.py
+++ b/otx/algorithms/segmentation/adapters/mmseg/configurer.py
@@ -201,10 +201,7 @@ def configure_ignore(self, cfg: Config) -> None:
             )
 
         if "decode_head" in cfg.model:
-            decode_head = cfg.model.decode_head
-            if decode_head.type == "FCNHead":
-                decode_head.type = "CustomFCNHead"
-                decode_head.loss_decode = cfg_loss_decode
+            cfg.model.decode_head.loss_decode = cfg_loss_decode
 
     # pylint: disable=too-many-branches
     def configure_classes(self, cfg: Config) -> None:
@@ -537,10 +534,6 @@ def configure_task(self, cfg: ConfigDict, training: bool, **kwargs: Any) -> None
         """Adjust settings for task adaptation."""
         super().configure_task(cfg, training, **kwargs)
 
-        # Don't pass task_adapt arg to semi-segmentor
-        if cfg.model.type != "ClassIncrEncoderDecoder" and cfg.model.get("task_adapt", False):
-            cfg.model.pop("task_adapt")
-
         # Remove task adapt hook (set default torch random sampler)
         remove_custom_hook(cfg, "TaskAdaptHook")
 
diff --git a/otx/algorithms/segmentation/adapters/mmseg/models/heads/custom_fcn_head.py b/otx/algorithms/segmentation/adapters/mmseg/models/heads/custom_fcn_head.py
index 20b3fb2039b..aa22e88c66f 100644
--- a/otx/algorithms/segmentation/adapters/mmseg/models/heads/custom_fcn_head.py
+++ b/otx/algorithms/segmentation/adapters/mmseg/models/heads/custom_fcn_head.py
@@ -4,26 +4,89 @@
 # SPDX-License-Identifier: Apache-2.0
 #
 
+from typing import Dict, List, Optional
+
+import torch
+from mmcv.runner import force_fp32
 from mmseg.models.builder import HEADS
 from mmseg.models.decode_heads.fcn_head import FCNHead
+from mmseg.models.losses import accuracy
+from mmseg.ops import resize
+from torch import nn
 
-from .mixin import (
-    AggregatorMixin,
-    MixLossMixin,
-    PixelWeightsMixin2,
-    SegmentOutNormMixin,
+from otx.algorithms.segmentation.adapters.mmseg.models.utils import IterativeAggregator
+from otx.algorithms.segmentation.adapters.mmseg.utils import (
+    get_valid_label_mask_per_batch,
 )
 
 
 @HEADS.register_module()
-class CustomFCNHead(
-    SegmentOutNormMixin, AggregatorMixin, MixLossMixin, PixelWeightsMixin2, FCNHead
-):  # pylint: disable=too-many-ancestors
-    """Custom Fully Convolution Networks for Semantic Segmentation."""
+class CustomFCNHead(FCNHead):
+    """Custom Fully Convolution Networks for Semantic Segmentation.
+
+    This FCN head adds the aggregator used in Lite-HRNet, built with
+    DepthwiseSeparableConvModule.
+    Please refer to https://github.com/HRNet/Lite-HRNet.
+
+    Args:
+        enable_aggregator (bool): If True, use the Lite-HRNet aggregator, which
+            concatenates all inputs from the backbone via DepthwiseSeparableConvModule.
+        aggregator_min_channels (int, optional): Minimum number of output channels
+            of the aggregator. Only used when enable_aggregator is True.
+        aggregator_merge_norm (str, optional): How to normalize the output of the
+            aggregator's expanders. Options: "none", "channel", or None.
+        aggregator_use_concat (bool, optional): Whether to concatenate the last input
+            with the output of the expanders.
+ """ + + def __init__( + self, + enable_aggregator: bool = False, + aggregator_min_channels: Optional[int] = None, + aggregator_merge_norm: Optional[str] = None, + aggregator_use_concat: bool = False, + *args, + **kwargs + ): + + in_channels = kwargs.get("in_channels") + in_index = kwargs.get("in_index") + norm_cfg = kwargs.get("norm_cfg") + conv_cfg = kwargs.get("conv_cfg") + input_transform = kwargs.get("input_transform") + + aggregator = None + if enable_aggregator: # Lite-HRNet aggregator + assert isinstance(in_channels, (tuple, list)) + assert len(in_channels) > 1 + + aggregator = IterativeAggregator( + in_channels=in_channels, + min_channels=aggregator_min_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + merge_norm=aggregator_merge_norm, + use_concat=aggregator_use_concat, + ) + + aggregator_min_channels = aggregator_min_channels if aggregator_min_channels is not None else 0 + # change arguments temporarily + kwargs["in_channels"] = max(in_channels[0], aggregator_min_channels) + kwargs["input_transform"] = None + if in_index is not None: + kwargs["in_index"] = in_index[0] - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.aggregator = aggregator + + # re-define variables + self.in_channels = in_channels + self.input_transform = input_transform + self.in_index = in_index + + self.ignore_index = 255 + # get rid of last activation of convs module if self.act_cfg: self.convs[-1].with_activation = False @@ -31,3 +94,68 @@ def __init__(self, *args, **kwargs): if kwargs.get("init_cfg", {}): self.init_weights() + + def _transform_inputs(self, inputs: torch.Tensor): + if self.aggregator is not None: + inputs = self.aggregator(inputs)[0] + else: + inputs = super()._transform_inputs(inputs) + + return inputs + + def forward_train( + self, inputs: torch.Tensor, img_metas: List[Dict], gt_semantic_seg: torch.Tensor, train_cfg: Dict + ): + """Forward function for training. + + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + gt_semantic_seg (Tensor): Semantic segmentation masks + used if the architecture supports semantic segmentation task. + train_cfg (dict): The training config. 
+
+        Returns:
+            dict[str, Tensor]: a dictionary of loss components
+        """
+        seg_logits = self(inputs)
+        valid_label_mask = get_valid_label_mask_per_batch(img_metas, self.num_classes)
+        losses = self.losses(seg_logits, gt_semantic_seg, valid_label_mask=valid_label_mask)
+        return losses
+
+    @force_fp32(apply_to=("seg_logit",))
+    def losses(self, seg_logit: torch.Tensor, seg_label: torch.Tensor, valid_label_mask: Optional[torch.Tensor] = None):
+        """Compute segmentation loss."""
+        loss = dict()
+
+        seg_logit = resize(input=seg_logit, size=seg_label.shape[2:], mode="bilinear", align_corners=self.align_corners)
+        if self.sampler is not None:
+            seg_weight = self.sampler.sample(seg_logit, seg_label)
+        else:
+            seg_weight = None
+        seg_label = seg_label.squeeze(1)
+
+        if not isinstance(self.loss_decode, nn.ModuleList):
+            losses_decode = [self.loss_decode]
+        else:
+            losses_decode = self.loss_decode
+        for loss_decode in losses_decode:
+            valid_label_mask_cfg = dict()
+            if loss_decode.loss_name == "loss_ce_ignore":
+                valid_label_mask_cfg["valid_label_mask"] = valid_label_mask
+            if loss_decode.loss_name not in loss:
+                loss[loss_decode.loss_name] = loss_decode(
+                    seg_logit, seg_label, weight=seg_weight, ignore_index=self.ignore_index, **valid_label_mask_cfg
+                )
+            else:
+                loss[loss_decode.loss_name] += loss_decode(
+                    seg_logit, seg_label, weight=seg_weight, ignore_index=self.ignore_index, **valid_label_mask_cfg
+                )
+
+        loss["acc_seg"] = accuracy(seg_logit, seg_label, ignore_index=self.ignore_index)
+
+        return loss
diff --git a/otx/algorithms/segmentation/adapters/mmseg/models/losses/cross_entropy_loss_with_ignore.py b/otx/algorithms/segmentation/adapters/mmseg/models/losses/cross_entropy_loss_with_ignore.py
index 57e9c24c268..16c1f83718c 100644
--- a/otx/algorithms/segmentation/adapters/mmseg/models/losses/cross_entropy_loss_with_ignore.py
+++ b/otx/algorithms/segmentation/adapters/mmseg/models/losses/cross_entropy_loss_with_ignore.py
@@ -3,61 +3,97 @@
 # SPDX-License-Identifier: Apache-2.0
 #
 
+from typing import Optional
+
 import torch
 import torch.nn.functional as F
 from mmseg.models.builder import LOSSES
-from mmseg.models.losses.utils import get_class_weight
-
-from .otx_pixel_base import OTXBasePixelLoss
+from mmseg.models.losses import CrossEntropyLoss
+from mmseg.models.losses.utils import weight_reduce_loss
 
 
 @LOSSES.register_module()
-class CrossEntropyLossWithIgnore(OTXBasePixelLoss):
+class CrossEntropyLossWithIgnore(CrossEntropyLoss):
     """CrossEntropyLossWithIgnore with Ignore Mode Support for Class Incremental Learning.
 
-    Args:
-        model_classes (list[str]): Model classes
-        bg_aware (bool, optional): Whether to enable BG-aware loss
-            'background' class would be added the start of model classes/label schema
-        reduction (str, optional): . Defaults to 'mean'.
-            Options are "none", "mean" and "sum".
-        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
+    When new classes are added through continual training cycles, images from previous cycles
+    may become partially annotated if they are not revisited.
+    To prevent the model from predicting these new classes for such images,
+    CrossEntropyLossWithIgnore can be used to ignore the unseen classes.
""" - def __init__(self, reduction="mean", loss_weight=None, **kwargs): - super().__init__(**kwargs) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._loss_name = "loss_ce_ignore" + + def forward( + self, + cls_score: Optional[torch.Tensor], + label: Optional[torch.Tensor], + weight: Optional[torch.Tensor] = None, + avg_factor: Optional[int] = None, + reduction_override: Optional[str] = "mean", + ignore_index: int = 255, + valid_label_mask: Optional[torch.Tensor] = None, + **kwargs + ): + """Forward. + + Args: + cls_score (torch.Tensor, optional): The prediction with shape (N, 1). + label (torch.Tensor, optional): The learning label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. + Default: None. + class_weight (list[float], optional): The weight for each class. + Default: None. + avg_factor (int, optional): Average factor that is used to average + the loss. Default: None. + reduction_override (str, optional): The method used to reduce the loss. + Options are 'none', 'mean' and 'sum'. Default: 'mean'. + ignore_index (int): Specifies a target value that is ignored and + does not contribute to the input gradients. When + ``avg_non_ignore `` is ``True``, and the ``reduction`` is + ``''mean''``, the loss is averaged over non-ignored targets. + Defaults: 255. + valid_label_mask (torch.Tensor, optional): The valid labels with + shape (N, num_classes). + If the value in the valid_label_mask is 0, mask label of the + the mask label of the class corresponding to its index will be + ignored like ignore_index. + **kwargs (Any): Additional keyword arguments. + """ + if valid_label_mask is None: + losses = super().forward(cls_score, label, weight, avg_factor, reduction_override, ignore_index, **kwargs) + return losses + else: + assert reduction_override in (None, "none", "mean", "sum") + reduction = reduction_override if reduction_override else self.reduction + batch_size = label.shape[0] + for i in range(batch_size): + invalid_labels = (valid_label_mask[i] == 0).nonzero(as_tuple=False) + + for inv_l in invalid_labels: + label[i] = torch.where(label[i] == inv_l.item(), ignore_index, label[i]) - self.reduction = reduction - self.class_weight = get_class_weight(loss_weight) + losses = F.cross_entropy(cls_score, label, reduction="none", ignore_index=ignore_index) + + if weight is not None: + weight = weight.float() + losses = weight_reduce_loss(losses, weight=weight, reduction=reduction, avg_factor=avg_factor) + + return losses @property - def name(self): - """name.""" - return "ce_with_ignore" - - def _calculate(self, cls_score, label, valid_label_mask, scale): - if cls_score.shape[0] == 0: - return torch.tensor(0.0) - - batch_size = label.shape[0] - label = torch.from_numpy(label).to(cls_score.device) - probs_all = F.softmax(scale * cls_score, dim=1) - losses_l = [] - for i in range(batch_size): - probs_gathered = probs_all[i, valid_label_mask[i] == 1] - probs_nomatch = probs_all[i, valid_label_mask[i] == 0] - probs_gathered = torch.unsqueeze(probs_gathered, 0) - probs_nomatch = torch.unsqueeze(probs_nomatch, 0) - - probs_gathered[:, 0] += probs_nomatch.sum(dim=1) - each_prob_log = torch.log(probs_gathered) - - # X-entropy: NLL loss w/ log-probabilities & labels - each_label = torch.unsqueeze(label[i], 0) - each_label = each_label.to(cls_score.device) - loss = F.nll_loss(each_prob_log, each_label, reduction="none", ignore_index=self.ignore_index) - losses_l.append(loss) - - losses = torch.cat(losses_l, dim=0) - - return losses, cls_score + 
+    def loss_name(self):
+        """Loss Name.
+
+        This function must be implemented and will return the name of this
+        loss function. This name will be used to combine different loss items
+        by simple sum operation. In addition, if you want this loss item to be
+        included into the backward graph, `loss_` must be the prefix of the
+        name.
+
+        Returns:
+            str: The name of this loss item.
+        """
+        return self._loss_name
diff --git a/otx/algorithms/segmentation/adapters/mmseg/models/segmentors/detcon.py b/otx/algorithms/segmentation/adapters/mmseg/models/segmentors/detcon.py
index 2dfb57f4ff3..c083537d08a 100644
--- a/otx/algorithms/segmentation/adapters/mmseg/models/segmentors/detcon.py
+++ b/otx/algorithms/segmentation/adapters/mmseg/models/segmentors/detcon.py
@@ -26,7 +26,7 @@
 
 from otx.algorithms.common.utils.logger import get_logger
 
-from .class_incr_encoder_decoder import ClassIncrEncoderDecoder
+from .class_incr_encoder_decoder import OTXEncoderDecoder
 
 logger = get_logger()
 
@@ -442,7 +442,7 @@ def state_dict_hook(module, state_dict, *args, **kwargs):
 
 # pylint: disable=too-many-locals
 @SEGMENTORS.register_module()
-class SupConDetConB(ClassIncrEncoderDecoder):  # pylint: disable=too-many-ancestors
+class SupConDetConB(OTXEncoderDecoder):  # pylint: disable=too-many-ancestors
     """Apply DetConB as a contrastive part of `Supervised Contrastive Learning` (https://arxiv.org/abs/2004.11362).
 
     SupCon with DetConB uses ground truth masks instead of pseudo masks to organize features among the same classes.
@@ -497,7 +497,6 @@ def forward_train(
         img,
         img_metas,
         gt_semantic_seg,
-        pixel_weights=None,
         **kwargs,
    ):
         """Forward function for training.
@@ -507,7 +506,6 @@
             img_metas (list[dict]): Input information.
             gt_semantic_seg (Tensor): Ground truth masks.
                 It is used to organize features among the same classes.
-            pixel_weights (Tensor): Pixels weights.
             **kwargs (Any): Addition keyword arguments.
 
         Returns:
@@ -527,9 +525,7 @@
             img_metas += img_metas
 
         # decode head
-        loss_decode, _ = self._decode_head_forward_train(
-            embds, img_metas, gt_semantic_seg=masks, pixel_weights=pixel_weights
-        )
+        loss_decode = self._decode_head_forward_train(embds, img_metas, gt_semantic_seg=masks)
         losses.update(loss_decode)
 
         return losses
diff --git a/otx/algorithms/segmentation/adapters/mmseg/models/segmentors/otx_encoder_decoder.py b/otx/algorithms/segmentation/adapters/mmseg/models/segmentors/otx_encoder_decoder.py
index 7c6ae32049f..8b7766f29a4 100644
--- a/otx/algorithms/segmentation/adapters/mmseg/models/segmentors/otx_encoder_decoder.py
+++ b/otx/algorithms/segmentation/adapters/mmseg/models/segmentors/otx_encoder_decoder.py
@@ -3,11 +3,15 @@
 # SPDX-License-Identifier: Apache-2.0
 #
 
+import functools
+
 import torch
 from mmseg.models import SEGMENTORS
 from mmseg.models.segmentors.encoder_decoder import EncoderDecoder
+from mmseg.utils import get_root_logger
 
 from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
+from otx.algorithms.common.utils.task_adapt import map_class_names
 
 # pylint: disable=unused-argument, line-too-long
 
@@ -15,6 +19,21 @@
 class OTXEncoderDecoder(EncoderDecoder):
     """OTX encoder decoder."""
 
+    def __init__(self, task_adapt=None, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # Hook for class-sensitive weight loading
+        assert task_adapt is not None, "task_adapt must be set."
+
+        self._register_load_state_dict_pre_hook(
+            functools.partial(
+                self.load_state_dict_pre_hook,
+                self,  # model
+                task_adapt["dst_classes"],  # model_classes
+                task_adapt["src_classes"],  # chkpt_classes
+            )
+        )
+
     def simple_test(self, img, img_meta, rescale=True, output_logits=False):
         """Simple test with single image."""
         seg_logit = self.inference(img, img_meta, rescale)
@@ -33,6 +52,41 @@ def simple_test(self, img, img_meta, rescale=True, output_logits=False):
             seg_pred = list(seg_pred)
         return seg_pred
 
+    @staticmethod
+    def load_state_dict_pre_hook(
+        model, model_classes, chkpt_classes, chkpt_dict, prefix, *args, **kwargs
+    ):  # pylint: disable=too-many-locals, unused-argument
+        """Modify input state_dict according to class name matching before weight loading."""
+        logger = get_root_logger("INFO")
+        logger.info(f"----------------- OTXEncoderDecoder.load_state_dict_pre_hook() called w/ prefix: {prefix}")
+
+        # Dst to src mapping index
+        model_classes = list(model_classes)
+        chkpt_classes = list(chkpt_classes)
+        model2chkpt = map_class_names(model_classes, chkpt_classes)
+        logger.info(f"{chkpt_classes} -> {model_classes} ({model2chkpt})")
+
+        model_dict = model.state_dict()
+        param_names = [
+            "decode_head.conv_seg.weight",
+            "decode_head.conv_seg.bias",
+        ]
+        for model_name in param_names:
+            chkpt_name = prefix + model_name
+            if model_name not in model_dict or chkpt_name not in chkpt_dict:
+                logger.info(f"Skipping weight copy: {chkpt_name}")
+                continue
+
+            # Mix weights
+            model_param = model_dict[model_name].clone()
+            chkpt_param = chkpt_dict[chkpt_name]
+            for model_key, c in enumerate(model2chkpt):
+                if c >= 0:
+                    model_param[model_key].copy_(chkpt_param[c])
+
+            # Replace checkpoint weight by mixed weights
+            chkpt_dict[chkpt_name] = model_param
+
 
 if is_mmdeploy_enabled():
     from mmdeploy.core import FUNCTION_REWRITER
diff --git a/otx/algorithms/segmentation/configs/ocr_lite_hrnet_18/model.py b/otx/algorithms/segmentation/configs/ocr_lite_hrnet_18/model.py
index c802a39a80f..09828793f12 100644
--- a/otx/algorithms/segmentation/configs/ocr_lite_hrnet_18/model.py
+++ b/otx/algorithms/segmentation/configs/ocr_lite_hrnet_18/model.py
@@ -22,10 +22,10 @@
 ]
 model = dict(
-    type="ClassIncrEncoderDecoder",
+    type="OTXEncoderDecoder",
     pretrained=None,
     decode_head=dict(
-        type="FCNHead",
+        type="CustomFCNHead",
         in_channels=[40, 80, 160, 320],
         in_index=[0, 1, 2, 3],
         input_transform="multiple_select",
@@ -38,7 +38,6 @@
         norm_cfg=dict(type="BN", requires_grad=True),
         align_corners=False,
         enable_aggregator=True,
-        enable_out_norm=False,
         loss_decode=[
             dict(
                 type="CrossEntropyLoss",
diff --git a/otx/algorithms/segmentation/configs/ocr_lite_hrnet_18/semisl/model.py b/otx/algorithms/segmentation/configs/ocr_lite_hrnet_18/semisl/model.py
index c73630e39f3..77d4e8627bb 100644
--- a/otx/algorithms/segmentation/configs/ocr_lite_hrnet_18/semisl/model.py
+++ b/otx/algorithms/segmentation/configs/ocr_lite_hrnet_18/semisl/model.py
@@ -29,7 +29,7 @@
     test_cfg=dict(mode="whole", output_scale=5.0),
     pretrained=None,
     decode_head=dict(
-        type="FCNHead",
+        type="CustomFCNHead",
         in_channels=[40, 80, 160, 320],
         in_index=[0, 1, 2, 3],
         input_transform="multiple_select",
@@ -42,7 +42,6 @@
         norm_cfg=dict(type="BN", requires_grad=True),
         align_corners=False,
         enable_aggregator=True,
-        enable_out_norm=False,
         loss_decode=[
             dict(
                 type="CrossEntropyLoss",
diff --git a/otx/algorithms/segmentation/configs/ocr_lite_hrnet_18_mod2/model.py b/otx/algorithms/segmentation/configs/ocr_lite_hrnet_18_mod2/model.py
index 4e7ce1bd9e0..8d1e6214d7b 100644
--- a/otx/algorithms/segmentation/configs/ocr_lite_hrnet_18_mod2/model.py
+++ b/otx/algorithms/segmentation/configs/ocr_lite_hrnet_18_mod2/model.py
@@ -22,10 +22,10 @@
 ]
 model = dict(
-    type="ClassIncrEncoderDecoder",
+    type="OTXEncoderDecoder",
     pretrained=None,
     decode_head=dict(
-        type="FCNHead",
+        type="CustomFCNHead",
         in_channels=[40, 80, 160, 320],
         in_index=[0, 1, 2, 3],
         input_transform="multiple_select",
@@ -38,7 +38,6 @@
         norm_cfg=dict(type="BN", requires_grad=True),
         align_corners=False,
         enable_aggregator=True,
-        enable_out_norm=False,
         loss_decode=[
             dict(
                 type="CrossEntropyLoss",
diff --git a/otx/algorithms/segmentation/configs/ocr_lite_hrnet_18_mod2/semisl/model.py b/otx/algorithms/segmentation/configs/ocr_lite_hrnet_18_mod2/semisl/model.py
index a6284b24042..2038a00165d 100644
--- a/otx/algorithms/segmentation/configs/ocr_lite_hrnet_18_mod2/semisl/model.py
+++ b/otx/algorithms/segmentation/configs/ocr_lite_hrnet_18_mod2/semisl/model.py
@@ -29,7 +29,7 @@
     test_cfg=dict(mode="whole", output_scale=5.0),
     pretrained=None,
     decode_head=dict(
-        type="FCNHead",
+        type="CustomFCNHead",
         in_channels=[40, 80, 160, 320],
         in_index=[0, 1, 2, 3],
         input_transform="multiple_select",
@@ -42,7 +42,6 @@
         norm_cfg=dict(type="BN", requires_grad=True),
         align_corners=False,
         enable_aggregator=True,
-        enable_out_norm=False,
         loss_decode=[
             dict(
                 type="CrossEntropyLoss",
diff --git a/otx/algorithms/segmentation/configs/ocr_lite_hrnet_18_mod2/supcon/model.py b/otx/algorithms/segmentation/configs/ocr_lite_hrnet_18_mod2/supcon/model.py
index 009ed8fe4f9..d78b82d8116 100644
--- a/otx/algorithms/segmentation/configs/ocr_lite_hrnet_18_mod2/supcon/model.py
+++ b/otx/algorithms/segmentation/configs/ocr_lite_hrnet_18_mod2/supcon/model.py
@@ -54,7 +54,7 @@
         loss_cfg=dict(type="DetConLoss", temperature=0.1),
     ),
     decode_head=dict(
-        type="FCNHead",
+        type="CustomFCNHead",
         in_channels=[40, 80, 160, 320],
         in_index=[0, 1, 2, 3],
         input_transform="multiple_select",
@@ -67,7 +67,6 @@
         norm_cfg=dict(type="BN", requires_grad=True),
         align_corners=False,
         enable_aggregator=True,
-        enable_out_norm=False,
         loss_decode=[
             dict(
                 type="CrossEntropyLoss",
diff --git a/otx/algorithms/segmentation/configs/ocr_lite_hrnet_s_mod2/model.py b/otx/algorithms/segmentation/configs/ocr_lite_hrnet_s_mod2/model.py
index 512af483f35..f1f99bfa375 100644
--- a/otx/algorithms/segmentation/configs/ocr_lite_hrnet_s_mod2/model.py
+++ b/otx/algorithms/segmentation/configs/ocr_lite_hrnet_s_mod2/model.py
@@ -22,10 +22,10 @@
 ]
 model = dict(
-    type="ClassIncrEncoderDecoder",
+    type="OTXEncoderDecoder",
     pretrained=None,
     decode_head=dict(
-        type="FCNHead",
+        type="CustomFCNHead",
         in_channels=[60, 120, 240],
         in_index=[0, 1, 2],
         input_transform="multiple_select",
@@ -40,8 +40,6 @@
         enable_aggregator=True,
         aggregator_merge_norm=None,
         aggregator_use_concat=False,
-        enable_out_norm=False,
-        enable_loss_equalizer=True,
         loss_decode=[
             dict(
                 type="CrossEntropyLoss",
@@ -49,12 +47,6 @@
                 loss_weight=1.0,
             ),
         ],
-        init_cfg=dict(
-            type="Normal",
-            mean=0,
-            std=0.01,
-            override=dict(name="conv_seg"),
-        ),
     ),
 )
diff --git a/otx/algorithms/segmentation/configs/ocr_lite_hrnet_s_mod2/semisl/model.py b/otx/algorithms/segmentation/configs/ocr_lite_hrnet_s_mod2/semisl/model.py
index a06f6389b25..73438d849ca 100644
--- a/otx/algorithms/segmentation/configs/ocr_lite_hrnet_s_mod2/semisl/model.py
+++ b/otx/algorithms/segmentation/configs/ocr_lite_hrnet_s_mod2/semisl/model.py
@@ -30,7 +30,7 @@
     test_cfg=dict(mode="whole", output_scale=5.0),
     pretrained=None,
     decode_head=dict(
-        type="FCNHead",
+        type="CustomFCNHead",
         in_channels=[60, 120, 240],
         in_index=[0, 1, 2],
         input_transform="multiple_select",
@@ -45,8 +45,6 @@
         enable_aggregator=True,
         aggregator_merge_norm=None,
         aggregator_use_concat=False,
-        enable_out_norm=False,
-        enable_loss_equalizer=True,
         loss_decode=[
             dict(
                 type="CrossEntropyLoss",
diff --git a/otx/algorithms/segmentation/configs/ocr_lite_hrnet_s_mod2/supcon/model.py b/otx/algorithms/segmentation/configs/ocr_lite_hrnet_s_mod2/supcon/model.py
index eb025a041de..04fd2fca351 100644
--- a/otx/algorithms/segmentation/configs/ocr_lite_hrnet_s_mod2/supcon/model.py
+++ b/otx/algorithms/segmentation/configs/ocr_lite_hrnet_s_mod2/supcon/model.py
@@ -54,7 +54,7 @@
         loss_cfg=dict(type="DetConLoss", temperature=0.1),
     ),
     decode_head=dict(
-        type="FCNHead",
+        type="CustomFCNHead",
         in_channels=[60, 120, 240],
         in_index=[0, 1, 2],
         input_transform="multiple_select",
@@ -69,8 +69,6 @@
         enable_aggregator=True,
         aggregator_merge_norm=None,
         aggregator_use_concat=False,
-        enable_out_norm=False,
-        enable_loss_equalizer=True,
         loss_decode=[
             dict(
                 type="CrossEntropyLoss",
diff --git a/otx/algorithms/segmentation/configs/ocr_lite_hrnet_x_mod3/model.py b/otx/algorithms/segmentation/configs/ocr_lite_hrnet_x_mod3/model.py
index 2300b99fcde..f8476d9c4d6 100644
--- a/otx/algorithms/segmentation/configs/ocr_lite_hrnet_x_mod3/model.py
+++ b/otx/algorithms/segmentation/configs/ocr_lite_hrnet_x_mod3/model.py
@@ -22,10 +22,10 @@
 ]
 model = dict(
-    type="ClassIncrEncoderDecoder",
+    type="OTXEncoderDecoder",
     pretrained=None,
     decode_head=dict(
-        type="FCNHead",
+        type="CustomFCNHead",
         in_channels=[18, 60, 80, 160, 320],
         in_index=[0, 1, 2, 3, 4],
         input_transform="multiple_select",
@@ -41,8 +41,6 @@
         aggregator_min_channels=60,
         aggregator_merge_norm=None,
         aggregator_use_concat=False,
-        enable_out_norm=False,
-        enable_loss_equalizer=True,
         loss_decode=[
             dict(
                 type="CrossEntropyLoss",
diff --git a/otx/algorithms/segmentation/configs/ocr_lite_hrnet_x_mod3/semisl/model.py b/otx/algorithms/segmentation/configs/ocr_lite_hrnet_x_mod3/semisl/model.py
index 3925cba64ed..8e29a5e9111 100644
--- a/otx/algorithms/segmentation/configs/ocr_lite_hrnet_x_mod3/semisl/model.py
+++ b/otx/algorithms/segmentation/configs/ocr_lite_hrnet_x_mod3/semisl/model.py
@@ -30,7 +30,7 @@
     test_cfg=dict(mode="whole", output_scale=5.0),
     pretrained=None,
     decode_head=dict(
-        type="FCNHead",
+        type="CustomFCNHead",
         in_channels=[18, 60, 80, 160, 320],
         in_index=[0, 1, 2, 3, 4],
         input_transform="multiple_select",
@@ -46,8 +46,6 @@
         aggregator_min_channels=60,
         aggregator_merge_norm=None,
         aggregator_use_concat=False,
-        enable_out_norm=False,
-        enable_loss_equalizer=True,
         loss_decode=[
             dict(
                 type="CrossEntropyLoss",
diff --git a/otx/algorithms/segmentation/configs/ocr_lite_hrnet_x_mod3/supcon/model.py b/otx/algorithms/segmentation/configs/ocr_lite_hrnet_x_mod3/supcon/model.py
index c6878a5d56b..8622b251ff6 100644
--- a/otx/algorithms/segmentation/configs/ocr_lite_hrnet_x_mod3/supcon/model.py
+++ b/otx/algorithms/segmentation/configs/ocr_lite_hrnet_x_mod3/supcon/model.py
@@ -54,7 +54,7 @@
         loss_cfg=dict(type="DetConLoss", temperature=0.1),
     ),
     decode_head=dict(
-        type="FCNHead",
+        type="CustomFCNHead",
         in_channels=[18, 60, 80, 160, 320],
         in_index=[0, 1, 2, 3, 4],
         input_transform="multiple_select",
@@ -70,8 +70,6 @@
         aggregator_min_channels=60,
         aggregator_merge_norm=None,
         aggregator_use_concat=False,
-        enable_out_norm=False,
-        enable_loss_equalizer=True,
         loss_decode=[
             dict(
                 type="CrossEntropyLoss",
diff --git a/otx/algorithms/segmentation/task.py b/otx/algorithms/segmentation/task.py
index 82dd39b2799..da59baa3bfb 100644
--- a/otx/algorithms/segmentation/task.py
+++ b/otx/algorithms/segmentation/task.py
@@ -273,7 +273,6 @@ def evaluate(
 
     def _add_predictions_to_dataset(self, prediction_results, dataset, dump_soft_prediction):
         """Loop over dataset again to assign predictions. Convert from MMSegmentation format to OTX format."""
-
         for dataset_item, (prediction, feature_vector) in zip(dataset, prediction_results):
             soft_prediction = np.transpose(prediction[0], axes=(1, 2, 0))
             hard_prediction = create_hard_prediction_from_soft_prediction(
diff --git a/otx/cli/manager/config_manager.py b/otx/cli/manager/config_manager.py
index 1b1b4e0e465..8fdc5a48c9f 100644
--- a/otx/cli/manager/config_manager.py
+++ b/otx/cli/manager/config_manager.py
@@ -530,7 +530,7 @@ def _copy_config_files(self, target_dir: Path, file_name: str, dest_dir: Path) -
             # FIXME: In the CLI, there is currently no case for using the ignore label.
             # so the workspace's model patches ignore to False.
             # FIXME: Segmentation -> ignore=True
-            if config.get("ignore", None) and str(self.task_type).upper() not in ("SEGMENTATION"):
+            if config.get("ignore", None):
                 config.ignore = False
                 print("In the CLI, Update ignore to false in model configuration.")
             config.dump(str(dest_dir / file_name))
diff --git a/otx/recipes/stages/_base_/models/segmentors/seg_class_incr.py b/otx/recipes/stages/_base_/models/segmentors/seg_class_incr.py
index 02de5962ee2..12e743c3559 100644
--- a/otx/recipes/stages/_base_/models/segmentors/seg_class_incr.py
+++ b/otx/recipes/stages/_base_/models/segmentors/seg_class_incr.py
@@ -2,21 +2,21 @@
 __norm_cfg = dict(type="BN", requires_grad=True)
 model = dict(
-    type="ClassIncrEncoderDecoder",
+    type="OTXEncoderDecoder",
+    pretrained=None,
     decode_head=dict(
-        type="FCNHead",
+        type="CustomFCNHead",
         in_channels=40,
         in_index=0,
         channels=40,
         input_transform=None,
         kernel_size=1,
-        num_convs=0,
+        num_convs=1,
         concat_input=False,
         dropout_ratio=-1,
-        num_classes=19,
+        num_classes=2,
         norm_cfg=__norm_cfg,
         align_corners=False,
-        enable_out_norm=False,
         loss_decode=[
             dict(
                 type="CrossEntropyLoss",
diff --git a/tests/unit/algorithms/segmentation/adapters/mmseg/models/losses/test_cross_entropy_loss_with_ignore.py b/tests/unit/algorithms/segmentation/adapters/mmseg/models/losses/test_cross_entropy_loss_with_ignore.py
new file mode 100644
index 00000000000..6350be9055f
--- /dev/null
+++ b/tests/unit/algorithms/segmentation/adapters/mmseg/models/losses/test_cross_entropy_loss_with_ignore.py
@@ -0,0 +1,46 @@
+import pytest
+import torch
+
+from otx.algorithms.segmentation.adapters.mmseg.models.losses.cross_entropy_loss_with_ignore import (
+    CrossEntropyLossWithIgnore,
+)
+from mmseg.models.losses import CrossEntropyLoss
+
+from tests.test_suite.e2e_test_system import e2e_pytest_unit
+
+
+class TestCrossEntropyLossWithIgnore:
+    @pytest.fixture(autouse=True)
+    def setup(self):
+        self.mock_score = torch.rand([1, 2, 5, 5])
+        self.mock_gt = torch.zeros((1, 5, 5), dtype=torch.long)
+        self.mock_gt[::2, 1::2, :] = 1
+        self.mock_gt[1::2, ::2, :] = 1
+
+        self.loss_f = CrossEntropyLossWithIgnore()
+
+    @e2e_pytest_unit
+    def test_is_label_ignored(self):
+        loss = self.loss_f(self.mock_score, self.mock_gt, reduction_override="none")
+        assert type(loss) == torch.Tensor
+
+        mock_valid_label_mask = torch.Tensor([[1, 0]])
+        loss_ignore = self.loss_f(
+            self.mock_score, self.mock_gt, reduction_override="none", valid_label_mask=mock_valid_label_mask
+        )
+
+        assert torch.all(loss_ignore[::2, 1::2, :] == 0)
+        assert torch.all(loss_ignore[1::2, ::2, :] == 0)
+
+        assert torch.equal(loss, loss_ignore) is False
+
+    @e2e_pytest_unit
+    def test_is_equal_to_ce_loss(self):
+        loss_f_mmseg = CrossEntropyLoss()
+
+        loss_1 = loss_f_mmseg(self.mock_score, self.mock_gt)
+        loss_2 = self.loss_f(self.mock_score, self.mock_gt)
+        loss_3 = self.loss_f(self.mock_score, self.mock_gt, valid_label_mask=torch.Tensor([1, 1]))
+
+        assert loss_1 == loss_2
+        assert loss_2 == loss_3
diff --git a/tests/unit/algorithms/segmentation/adapters/mmseg/models/segmentors/test_detcon.py b/tests/unit/algorithms/segmentation/adapters/mmseg/models/segmentors/test_detcon.py
index 99e4f0af4c5..fca2da73d4a 100644
--- a/tests/unit/algorithms/segmentation/adapters/mmseg/models/segmentors/test_detcon.py
+++ b/tests/unit/algorithms/segmentation/adapters/mmseg/models/segmentors/test_detcon.py
@@ -88,7 +88,7 @@ def build_mock(mock_class, *args, **kwargs):
     )
     mocker.patch(
         "otx.algorithms.segmentation.adapters.mmseg.models.segmentors.detcon.SupConDetConB._decode_head_forward_train",
-        return_value=(dict(loss=1.0), None),
+        return_value=dict(loss=1.0),
     )
diff --git a/tests/unit/algorithms/segmentation/adapters/mmseg/test_mmseg_configurer.py b/tests/unit/algorithms/segmentation/adapters/mmseg/test_mmseg_configurer.py
index d347c87def3..6982a5c573e 100644
--- a/tests/unit/algorithms/segmentation/adapters/mmseg/test_mmseg_configurer.py
+++ b/tests/unit/algorithms/segmentation/adapters/mmseg/test_mmseg_configurer.py
@@ -241,7 +241,6 @@ def test_configure_task(self, mocker):
         model_cfg = ConfigDict(dict(model=dict(type="", task_adapt=True)))
         mock_remove_hook = mocker.patch("otx.algorithms.segmentation.adapters.mmseg.configurer.remove_custom_hook")
         self.configurer.configure_task(model_cfg, True)
-        assert "task_adapt" not in model_cfg.model
         mock_remove_hook.assert_called_once()
 
     @e2e_pytest_unit
diff --git a/tests/unit/algorithms/segmentation/adapters/mmseg/utils/test_builder.py b/tests/unit/algorithms/segmentation/adapters/mmseg/utils/test_builder.py
new file mode 100644
index 00000000000..29db15ff389
--- /dev/null
+++ b/tests/unit/algorithms/segmentation/adapters/mmseg/utils/test_builder.py
@@ -0,0 +1,26 @@
+import pytest
+from otx.algorithms.segmentation.adapters.mmseg.utils import build_scalar_scheduler, build_segmentor
+from tests.test_suite.e2e_test_system import e2e_pytest_unit
+
+
+@e2e_pytest_unit
+def test_build_scalar_scheduler(mocker):
+    cfg = mocker.MagicMock()
+    builder = mocker.patch("mmseg.models.builder.MODELS.build", return_value=True)
+    build_scalar_scheduler(cfg)
+    builder.assert_called_once_with(cfg)
+
+
+@e2e_pytest_unit
+def test_build_segmentor(mocker):
+    from mmcv.utils import Config
+
+    cfg = Config({"model": {}, "load_from": "foo.pth"})
+    mocker.patch("mmseg.models.build_segmentor")
+    load_ckpt = mocker.patch("otx.algorithms.segmentation.adapters.mmseg.utils.builder.load_checkpoint")
+    build_segmentor(cfg)
+    load_ckpt.assert_called()
+
+    build_segmentor(cfg, is_training=True)
+    load_ckpt.assert_called()
+    assert cfg.load_from is None
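A few illustrative sketches of the behavior introduced above. First, with enable_aggregator=True, CustomFCNHead.__init__ rewrites the arguments it forwards to FCNHead.__init__ so the parent head consumes the single aggregated feature map instead of the multi-level list. This is a minimal stand-alone sketch of that rewrite, using the values from the ocr_lite_hrnet_x_mod3 recipe (the IterativeAggregator construction itself is omitted):

```python
# Sketch of the argument rewrite in CustomFCNHead.__init__ when
# enable_aggregator=True; values taken from the ocr_lite_hrnet_x_mod3 recipe.
kwargs = dict(
    in_channels=[18, 60, 80, 160, 320],
    in_index=[0, 1, 2, 3, 4],
    input_transform="multiple_select",
)
aggregator_min_channels = 60

in_channels = kwargs["in_channels"]
# The aggregator collapses all levels into one map whose channel count is
# max(first level, aggregator_min_channels), so FCNHead sees scalar arguments.
kwargs["in_channels"] = max(in_channels[0], aggregator_min_channels)
kwargs["input_transform"] = None
kwargs["in_index"] = kwargs["in_index"][0]
print(kwargs)  # {'in_channels': 60, 'in_index': 0, 'input_transform': None}
```

The original list-valued in_channels, in_index, and input_transform are restored on the instance after FCNHead.__init__ returns, which is why _transform_inputs can still route multi-level inputs through the aggregator.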
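Second, the ignore mechanism in CrossEntropyLossWithIgnore.forward amounts to remapping, per image, every label whose class has a 0 entry in valid_label_mask to ignore_index before calling F.cross_entropy. A stand-alone sketch of just that step in plain PyTorch (an illustration, not the OTX implementation):

```python
# Labels of classes whose valid_label_mask entry is 0 are remapped to
# ignore_index, so F.cross_entropy excludes those pixels from the loss.
import torch
import torch.nn.functional as F

ignore_index = 255
cls_score = torch.randn(1, 2, 5, 5)            # (N, C, H, W) logits
label = torch.randint(0, 2, (1, 5, 5))         # (N, H, W) hard labels
valid_label_mask = torch.tensor([[1.0, 0.0]])  # class 1 unseen for image 0

for i in range(label.shape[0]):
    invalid_labels = (valid_label_mask[i] == 0).nonzero(as_tuple=False)
    for inv_l in invalid_labels:
        label[i] = torch.where(label[i] == inv_l.item(), ignore_index, label[i])

# Pixels labeled with unseen classes now carry ignore_index and contribute no loss.
loss = F.cross_entropy(cls_score, label, reduction="none", ignore_index=ignore_index)
assert torch.all(loss[label == ignore_index] == 0)
```

This is exactly the property the new unit test asserts: with valid_label_mask=[[1, 0]], every pixel labeled with class 1 produces zero loss under reduction_override="none".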
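Third, the load_state_dict_pre_hook registered by OTXEncoderDecoder mixes the per-class rows of conv_seg in the checkpoint by class-name matching, so classes kept across training cycles keep their pretrained weights while newly added classes keep their fresh initialization. A sketch of the row mixing, with a simplified stand-in for map_class_names (the real helper lives in otx.algorithms.common.utils.task_adapt and is assumed to return, per model class, the index of the same-named checkpoint class or -1):

```python
# Simplified illustration of the conv_seg weight remix; names are illustrative.
import torch

def map_class_names(model_classes, chkpt_classes):
    return [chkpt_classes.index(c) if c in chkpt_classes else -1 for c in model_classes]

chkpt_classes = ["background", "car"]
model_classes = ["background", "car", "person"]  # "person" added this cycle
model2chkpt = map_class_names(model_classes, chkpt_classes)  # [0, 1, -1]

chkpt_w = torch.randn(2, 40, 1, 1)  # old conv_seg.weight: (num_old_classes, C, 1, 1)
model_w = torch.randn(3, 40, 1, 1)  # new conv_seg.weight: (num_new_classes, C, 1, 1)

for model_key, c in enumerate(model2chkpt):
    if c >= 0:
        model_w[model_key].copy_(chkpt_w[c])  # reuse weights of matched classes

# model_w now holds pretrained rows for "background" and "car" plus the fresh
# row for "person"; the hook writes this mixed tensor back into the checkpoint.
```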