diff --git a/CHANGELOG.md b/CHANGELOG.md
index 11ed67e7ae3..bc6ba91edc2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,7 +6,7 @@ All notable changes to this project will be documented in this file.
 
 ### New features
 
-- Add RT-DETR model for object detection task
+- Add RT-DETR model for Object Detection
   (https://github.com/openvinotoolkit/training_extensions/pull/3741)
 - Add Multi-Label & H-label Classification with torchvision models
   (https://github.com/openvinotoolkit/training_extensions/pull/3697)
@@ -14,7 +14,7 @@ All notable changes to this project will be documented in this file.
   (https://github.com/openvinotoolkit/training_extensions/pull/3710)
 - Add LoRA finetuning capability for ViT Architectures
   (https://github.com/openvinotoolkit/training_extensions/pull/3729)
-- Add Hugging-Face Model Wrapper for Detection
+- Add Hugging-Face Model Wrapper for Object Detection
   (https://github.com/openvinotoolkit/training_extensions/pull/3747)
 - Add Hugging-Face Model Wrapper for Semantic Segmentation
   (https://github.com/openvinotoolkit/training_extensions/pull/3749)
@@ -24,6 +24,8 @@ All notable changes to this project will be documented in this file.
   (https://github.com/openvinotoolkit/training_extensions/pull/3762)
 - Add RTMPose for Keypoint Detection Task
   (https://github.com/openvinotoolkit/training_extensions/pull/3781)
+- Add Semi-SL MeanTeacher algorithm for Semantic Segmentation
+  (https://github.com/openvinotoolkit/training_extensions/pull/3801)
 - Update head and h-label format for hierarchical label classification
   (https://github.com/openvinotoolkit/training_extensions/pull/3810)
diff --git a/src/otx/algo/callbacks/ema_mean_teacher.py b/src/otx/algo/callbacks/ema_mean_teacher.py
new file mode 100644
index 00000000000..391ffb9067d
--- /dev/null
+++ b/src/otx/algo/callbacks/ema_mean_teacher.py
@@ -0,0 +1,81 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+"""Module for exponential moving average for SemiSL mean teacher algorithm."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+import torch
+from lightning import Callback, LightningModule, Trainer
+
+if TYPE_CHECKING:
+    from lightning.pytorch.utilities.types import STEP_OUTPUT
+
+
+class EMAMeanTeacher(Callback):
+    """Callback for the SemiSL MeanTeacher algorithm.
+
+    This callback updates the teacher model as an exponential moving average of the student model.
+
+    Args:
+        momentum (float, optional): EMA momentum coefficient. Defaults to 0.999.
+        start_epoch (int, optional): epoch from which the EMA updates begin. Defaults to 1.
+    """
+
+    def __init__(
+        self,
+        momentum: float = 0.999,
+        start_epoch: int = 1,
+    ) -> None:
+        super().__init__()
+        self.momentum = momentum
+        self.start_epoch = start_epoch
+        self.synced_models = False
+
+    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        """Set up src & dst model parameters."""
+        # access the underlying nn.Module wrapped by the LightningModule
+        model = trainer.model.model
+        self.src_model = getattr(model, "student_model", None)
+        self.dst_model = getattr(model, "teacher_model", None)
+        if self.src_model is None or self.dst_model is None:
+            msg = "student_model and teacher_model should be set for the MeanTeacher algorithm"
+            raise RuntimeError(msg)
+        self.src_params = self.src_model.state_dict(keep_vars=True)
+        self.dst_params = self.dst_model.state_dict(keep_vars=True)
+
+    def on_train_batch_end(
+        self,
+        trainer: Trainer,
+        pl_module: LightningModule,
+        outputs: STEP_OUTPUT,
+        batch: Any,  # noqa: ANN401
+        batch_idx: int,
+    ) -> None:
+        """Update the EMA parameters on every iteration."""
+        if trainer.current_epoch < self.start_epoch:
+            return
+
+        # EMA
+        self._ema_model(trainer.global_step)
+
+    def _copy_model(self) -> None:
+        with torch.no_grad():
+            for name, src_param in self.src_params.items():
+                if src_param.requires_grad:
+                    dst_param = self.dst_params[name]
+                    dst_param.data.copy_(src_param.data)
+
+    def _ema_model(self, global_step: int) -> None:
+        if self.start_epoch != 0 and not self.synced_models:
+            self._copy_model()
+            self.synced_models = True
+
+        momentum = min(1 - 1 / (global_step + 1), self.momentum)
+        with torch.no_grad():
+            for name, src_param in self.src_params.items():
+                if src_param.requires_grad:
+                    dst_param = self.dst_params[name]
+                    dst_param.data.copy_(dst_param.data * momentum + src_param.data * (1 - momentum))
+ """ + # past implementation + w = size[2] + h = size[3] + b = size[0] + cut_rat = np.sqrt(1.0 - lam) + cut_w = int(w * cut_rat) + cut_h = int(h * cut_rat) + + cx = np.random.randint(size=[b], low=int(w / 8), high=w) + cy = np.random.randint(size=[b], low=int(h / 8), high=h) + + bbx1 = np.clip(cx - cut_w // 2, 0, w) + bby1 = np.clip(cy - cut_h // 2, 0, h) + + bbx2 = np.clip(cx + cut_w // 2, 0, w) + bby2 = np.clip(cy + cut_h // 2, 0, h) + + return bbx1, bby1, bbx2, bby2 + + target_device = images.device + mix_data = images.clone() + mix_masks = masks.clone() + u_rand_index = torch.randperm(images.size()[0])[: images.size()[0]].to(target_device) + u_bbx1, u_bby1, u_bbx2, u_bby2 = rand_bbox(images.size(), lam=np.random.beta(4, 4)) + + for i in range(mix_data.shape[0]): + mix_data[i, :, u_bbx1[i] : u_bbx2[i], u_bby1[i] : u_bby2[i]] = images[ + u_rand_index[i], + :, + u_bbx1[i] : u_bbx2[i], + u_bby1[i] : u_bby2[i], + ] + + mix_masks[i, :, u_bbx1[i] : u_bbx2[i], u_bby1[i] : u_bby2[i]] = masks[ + u_rand_index[i], + :, + u_bbx1[i] : u_bbx2[i], + u_bby1[i] : u_bby2[i], + ] + + del images, masks + + return mix_data, mix_masks.squeeze(dim=1) diff --git a/src/otx/algo/segmentation/__init__.py b/src/otx/algo/segmentation/__init__.py index db37ae66fa0..a4b4f86018a 100644 --- a/src/otx/algo/segmentation/__init__.py +++ b/src/otx/algo/segmentation/__init__.py @@ -3,6 +3,6 @@ # """Module for OTX segmentation models, hooks, utils, etc.""" -from . import backbones, heads, losses +from . import backbones, heads, losses, segmentors -__all__ = ["backbones", "heads", "losses"] +__all__ = ["backbones", "heads", "losses", "segmentors"] diff --git a/src/otx/algo/segmentation/dino_v2_seg.py b/src/otx/algo/segmentation/dino_v2_seg.py index d38001ada88..862f4e47ff4 100644 --- a/src/otx/algo/segmentation/dino_v2_seg.py +++ b/src/otx/algo/segmentation/dino_v2_seg.py @@ -7,12 +7,14 @@ from typing import TYPE_CHECKING, Any, ClassVar +import torch + from otx.algo.segmentation.backbones import DinoVisionTransformer from otx.algo.segmentation.heads import FCNHead +from otx.algo.segmentation.segmentors import BaseSegmModel, MeanTeacher +from otx.core.data.entity.segmentation import SegBatchDataEntity from otx.core.model.segmentation import TorchVisionCompatibleModel -from .base_model import BaseSegmModel - if TYPE_CHECKING: from torch import nn from typing_extensions import Self @@ -68,3 +70,46 @@ def to(self, *args, **kwargs) -> Self: msg = f"{type(self).__name__} doesn't support XPU." 
diff --git a/src/otx/algo/segmentation/__init__.py b/src/otx/algo/segmentation/__init__.py
index db37ae66fa0..a4b4f86018a 100644
--- a/src/otx/algo/segmentation/__init__.py
+++ b/src/otx/algo/segmentation/__init__.py
@@ -3,6 +3,6 @@
 #
 """Module for OTX segmentation models, hooks, utils, etc."""
 
-from . import backbones, heads, losses
+from . import backbones, heads, losses, segmentors
 
-__all__ = ["backbones", "heads", "losses"]
+__all__ = ["backbones", "heads", "losses", "segmentors"]
diff --git a/src/otx/algo/segmentation/dino_v2_seg.py b/src/otx/algo/segmentation/dino_v2_seg.py
index d38001ada88..862f4e47ff4 100644
--- a/src/otx/algo/segmentation/dino_v2_seg.py
+++ b/src/otx/algo/segmentation/dino_v2_seg.py
@@ -7,12 +7,14 @@
 
 from typing import TYPE_CHECKING, Any, ClassVar
 
+import torch
+
 from otx.algo.segmentation.backbones import DinoVisionTransformer
 from otx.algo.segmentation.heads import FCNHead
+from otx.algo.segmentation.segmentors import BaseSegmModel, MeanTeacher
+from otx.core.data.entity.segmentation import SegBatchDataEntity
 from otx.core.model.segmentation import TorchVisionCompatibleModel
 
-from .base_model import BaseSegmModel
-
 if TYPE_CHECKING:
     from torch import nn
     from typing_extensions import Self
@@ -68,3 +70,46 @@ def to(self, *args, **kwargs) -> Self:
             msg = f"{type(self).__name__} doesn't support XPU."
             raise RuntimeError(msg)
         return ret
+
+
+class DinoV2SegSemiSL(OTXDinoV2Seg):
+    """DinoV2SegSemiSL Model."""
+
+    def _customize_inputs(self, entity: SegBatchDataEntity) -> dict[str, Any]:
+        if not isinstance(entity, dict):
+            if self.training:
+                msg = "Unlabeled inputs must be provided for Semi-SL training."
+                raise RuntimeError(msg)
+            return super()._customize_inputs(entity)
+
+        entity["labeled"].masks = torch.stack(entity["labeled"].masks).long()
+        w_u_images = entity["weak_transforms"].images
+        s_u_images = entity["strong_transforms"].images
+        unlabeled_img_metas = entity["weak_transforms"].imgs_info
+        labeled_inputs = entity["labeled"]
+
+        return {
+            "inputs": labeled_inputs.images,
+            "unlabeled_weak_images": w_u_images,
+            "unlabeled_strong_images": s_u_images,
+            "global_step": self.trainer.global_step,
+            "steps_per_epoch": self.trainer.num_training_batches,
+            "img_metas": labeled_inputs.imgs_info,
+            "unlabeled_img_metas": unlabeled_img_metas,
+            "masks": labeled_inputs.masks,
+            "mode": "loss",
+        }
+
+    def _create_model(self) -> nn.Module:
+        # merge the user configuration over the defaults (user-provided keys win)
+        backbone_configuration = DinoV2Seg.default_backbone_configuration | self.backbone_configuration
+        decode_head_configuration = DinoV2Seg.default_decode_head_configuration | self.decode_head_configuration
+        backbone = DinoVisionTransformer(**backbone_configuration)
+        decode_head = FCNHead(num_classes=self.num_classes, **decode_head_configuration)
+        base_model = DinoV2Seg(
+            backbone=backbone,
+            decode_head=decode_head,
+            criterion_configuration=self.criterion_configuration,
+        )
+
+        return MeanTeacher(base_model, unsup_weight=0.7, drop_unrel_pixels_percent=20, semisl_start_epoch=2)
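All three Semi-SL model classes in this PR share the same `_customize_inputs` logic, so the kwargs handed to `MeanTeacher.forward` during training look like the sketch below (illustrative placeholder values, not OTX code; the real `img_metas` entries are `ImageInfo` objects taken from the batch):

```python
# Illustrative shape of the kwargs assembled by _customize_inputs for MeanTeacher.
import torch

batch_size, h, w = 2, 512, 512
forward_kwargs = {
    "inputs": torch.randn(batch_size, 3, h, w),                   # labeled images
    "unlabeled_weak_images": torch.randn(batch_size, 3, h, w),    # consumed by the teacher
    "unlabeled_strong_images": torch.randn(batch_size, 3, h, w),  # consumed by the student
    "global_step": 100,          # from self.trainer.global_step
    "steps_per_epoch": 50,       # from self.trainer.num_training_batches
    "img_metas": [],             # labeled ImageInfo list (placeholder)
    "unlabeled_img_metas": [],   # unlabeled ImageInfo list (placeholder)
    "masks": torch.randint(0, 2, (batch_size, h, w)).long(),
    "mode": "loss",
}
```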
diff --git a/src/otx/algo/segmentation/litehrnet.py b/src/otx/algo/segmentation/litehrnet.py
index 6684c63ed3b..e16f3d4d258 100644
--- a/src/otx/algo/segmentation/litehrnet.py
+++ b/src/otx/algo/segmentation/litehrnet.py
@@ -7,17 +7,18 @@
 
 from typing import TYPE_CHECKING, Any, ClassVar
 
+import torch
 from torch.onnx import OperatorExportTypes
 
 from otx.algo.segmentation.backbones import LiteHRNet
 from otx.algo.segmentation.heads import FCNHead
+from otx.algo.segmentation.segmentors import BaseSegmModel, MeanTeacher
 from otx.algo.utils.support_otx_v1 import OTXv1Helper
+from otx.core.data.entity.segmentation import SegBatchDataEntity
 from otx.core.exporter.base import OTXModelExporter
 from otx.core.exporter.native import OTXNativeModelExporter
 from otx.core.model.segmentation import TorchVisionCompatibleModel
 
-from .base_model import BaseSegmModel
-
 if TYPE_CHECKING:
     from torch import nn
 
@@ -574,3 +575,51 @@ def _exporter(self) -> OTXModelExporter:
             onnx_export_configuration={"operator_export_type": OperatorExportTypes.ONNX_ATEN_FALLBACK},
             output_names=None,
         )
+
+
+class LiteHRNetSemiSL(OTXLiteHRNet):
+    """LiteHRNetSemiSL Model."""
+
+    def _create_model(self) -> nn.Module:
+        litehrnet_model_class = LITEHRNET_VARIANTS[self.name_base_model]
+        # merge the user configuration over the defaults (user-provided keys win)
+        backbone_configuration = litehrnet_model_class.default_backbone_configuration | self.backbone_configuration
+        decode_head_configuration = (
+            litehrnet_model_class.default_decode_head_configuration | self.decode_head_configuration
+        )
+        # initialize backbones
+        backbone = LiteHRNet(**backbone_configuration)
+        decode_head = FCNHead(num_classes=self.num_classes, **decode_head_configuration)
+
+        base_model = litehrnet_model_class(
+            backbone=backbone,
+            decode_head=decode_head,
+            criterion_configuration=self.criterion_configuration,
+        )
+
+        return MeanTeacher(base_model, unsup_weight=0.7, drop_unrel_pixels_percent=20, semisl_start_epoch=2)
+
+    def _customize_inputs(self, entity: SegBatchDataEntity) -> dict[str, Any]:
+        if not isinstance(entity, dict):
+            if self.training:
+                msg = "Unlabeled inputs must be provided for Semi-SL training."
+                raise RuntimeError(msg)
+            return super()._customize_inputs(entity)
+
+        entity["labeled"].masks = torch.stack(entity["labeled"].masks).long()
+        w_u_images = entity["weak_transforms"].images
+        s_u_images = entity["strong_transforms"].images
+        unlabeled_img_metas = entity["weak_transforms"].imgs_info
+        labeled_inputs = entity["labeled"]
+
+        return {
+            "inputs": labeled_inputs.images,
+            "unlabeled_weak_images": w_u_images,
+            "unlabeled_strong_images": s_u_images,
+            "global_step": self.trainer.global_step,
+            "steps_per_epoch": self.trainer.num_training_batches,
+            "img_metas": labeled_inputs.imgs_info,
+            "unlabeled_img_metas": unlabeled_img_metas,
+            "masks": labeled_inputs.masks,
+            "mode": "loss",
+        }
diff --git a/src/otx/algo/segmentation/segmentors/__init__.py b/src/otx/algo/segmentation/segmentors/__init__.py
new file mode 100644
index 00000000000..7b7456cded1
--- /dev/null
+++ b/src/otx/algo/segmentation/segmentors/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+"""Module for base NN segmentation models."""
+
+from .base_model import BaseSegmModel
+from .mean_teacher import MeanTeacher
+
+__all__ = ["BaseSegmModel", "MeanTeacher"]
diff --git a/src/otx/algo/segmentation/base_model.py b/src/otx/algo/segmentation/segmentors/base_model.py
similarity index 58%
rename from src/otx/algo/segmentation/base_model.py
rename to src/otx/algo/segmentation/segmentors/base_model.py
index 057c9b3c1b4..c979be625ae 100644
--- a/src/otx/algo/segmentation/base_model.py
+++ b/src/otx/algo/segmentation/segmentors/base_model.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2023 Intel Corporation
+# Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 #
 """Base segmentation model."""
@@ -25,7 +25,7 @@ def __init__(
         decode_head: nn.Module,
         criterion_configuration: list[dict[str, str | Any]] | None = None,
     ) -> None:
-        """Initializes a SegNext model.
+        """Initializes a segmentation model.
 
         Args:
             backbone (nn.Module): The backbone of the segmentation model.
@@ -54,10 +54,10 @@ def forward(
         """Performs the forward pass of the model.
 
         Args:
-            inputs: Input images to the model.
-            img_metas: Image meta information. Defaults to None.
-            masks: Ground truth masks for training. Defaults to None.
-            mode: The mode of operation. Defaults to "tensor".
+            inputs (Tensor): Input images to the model.
+            img_metas (list[ImageInfo]): Image meta information. Defaults to None.
+            masks (Tensor): Ground truth masks for training. Defaults to None.
+            mode (str): The mode of operation. Defaults to "tensor".
 
         Returns:
             Depending on the mode:
@@ -67,12 +67,12 @@ def forward(
             - Otherwise, returns the model outputs after interpolation.
         """
         enc_feats = self.backbone(inputs)
-        outputs = self.decode_head(inputs=enc_feats)
+        outputs = self.decode_head(enc_feats)
+        outputs = f.interpolate(outputs, size=inputs.size()[2:], mode="bilinear", align_corners=True)
 
         if mode == "tensor":
             return outputs
 
-        outputs = f.interpolate(outputs, size=inputs.size()[-2:], mode="bilinear", align_corners=True)
         if mode == "loss":
             if masks is None:
                 msg = "The masks must be provided for training."
@@ -80,32 +80,57 @@
             if img_metas is None:
                 msg = "The image meta information must be provided for training."
                 raise ValueError(msg)
-            # class incremental training
-            valid_label_mask = self.get_valid_label_mask(img_metas)
-            output_losses = {}
-            for criterion in self.criterions:
-                valid_label_mask_cfg = {}
-                if criterion.name == "loss_ce_ignore":
-                    valid_label_mask_cfg["valid_label_mask"] = valid_label_mask
-                if criterion.name not in output_losses:
-                    output_losses[criterion.name] = criterion(
-                        outputs,
-                        masks,
-                        **valid_label_mask_cfg,
-                    )
-                else:
-                    output_losses[criterion.name] += criterion(
-                        outputs,
-                        masks,
-                        **valid_label_mask_cfg,
-                    )
-            return output_losses
+            return self.calculate_loss(outputs, img_metas, masks, interpolate=False)
 
         if mode == "predict":
             return outputs.argmax(dim=1)
 
         return outputs
 
+    def calculate_loss(
+        self,
+        model_features: Tensor,
+        img_metas: list[ImageInfo],
+        masks: Tensor,
+        interpolate: bool,
+    ) -> dict[str, Tensor]:
+        """Calculates the losses of the model.
+
+        Args:
+            model_features (Tensor): raw model outputs (logits) before interpolation.
+            img_metas (list[ImageInfo]): image meta information.
+            masks (Tensor): ground truth masks for training.
+            interpolate (bool): whether to resize the outputs to the input image shape first.
+
+        Returns:
+            dict[str, Tensor]: the computed losses, keyed by criterion name.
+        """
+        outputs = (
+            f.interpolate(model_features, size=img_metas[0].img_shape, mode="bilinear", align_corners=True)
+            if interpolate
+            else model_features
+        )
+        # class incremental training
+        valid_label_mask = self.get_valid_label_mask(img_metas)
+        output_losses = {}
+        for criterion in self.criterions:
+            valid_label_mask_cfg = {}
+            if criterion.name == "loss_ce_ignore":
+                valid_label_mask_cfg["valid_label_mask"] = valid_label_mask
+            if criterion.name not in output_losses:
+                output_losses[criterion.name] = criterion(
+                    outputs,
+                    masks,
+                    **valid_label_mask_cfg,
+                )
+            else:
+                output_losses[criterion.name] += criterion(
+                    outputs,
+                    masks,
+                    **valid_label_mask_cfg,
+                )
+        return output_losses
+
     def get_valid_label_mask(self, img_metas: list[ImageInfo]) -> list[Tensor]:
         """Get valid label mask removing ignored classes to zero mask in a batch.
diff --git a/src/otx/algo/segmentation/segmentors/mean_teacher.py b/src/otx/algo/segmentation/segmentors/mean_teacher.py
new file mode 100644
index 00000000000..45b344d884d
--- /dev/null
+++ b/src/otx/algo/segmentation/segmentors/mean_teacher.py
@@ -0,0 +1,154 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+"""Base mean teacher algorithm for semi-supervised semantic segmentation learning."""
+from __future__ import annotations
+
+import copy
+from typing import TYPE_CHECKING
+
+import numpy as np
+import torch
+from torch import Tensor, nn
+
+from otx.algo.common.utils.utils import cut_mixer
+
+if TYPE_CHECKING:
+    from otx.core.data.entity.base import ImageInfo
+
+
+class MeanTeacher(nn.Module):
+    """MeanTeacher for Semi-supervised learning.
+
+    Args:
+        model (nn.Module): the base segmentation model.
+        unsup_weight (float, optional): weight of the unsupervised loss. Defaults to 1.0.
+        drop_unrel_pixels_percent (int, optional): initial percent of unreliable pixels to drop. Defaults to 20.
+        semisl_start_epoch (int, optional): epoch from which unsupervised training starts. Defaults to 0.
+        filter_pixels_epochs (int, optional): epochs during which unreliable pixels are filtered. Defaults to 100.
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        unsup_weight: float = 1.0,
+        drop_unrel_pixels_percent: int = 20,
+        semisl_start_epoch: int = 0,
+        filter_pixels_epochs: int = 100,
+    ) -> None:
+        super().__init__()
+
+        self.teacher_model = model
+        self.student_model = copy.deepcopy(model)
+        # no grads for teacher model
+        for param in self.teacher_model.parameters():
+            param.requires_grad = False
+        self.unsup_weight = unsup_weight
+        self.drop_unrel_pixels_percent = drop_unrel_pixels_percent
+        # filter unreliable pixels during first X epochs
+        self.filter_pixels_epochs = filter_pixels_epochs
+        self.semisl_start_epoch = semisl_start_epoch
+
+    def forward(
+        self,
+        inputs: Tensor,
+        unlabeled_weak_images: Tensor | None = None,
+        unlabeled_strong_images: Tensor | None = None,
+        global_step: int | None = None,
+        steps_per_epoch: int | None = None,
+        img_metas: list[ImageInfo] | None = None,
+        unlabeled_img_metas: list[ImageInfo] | None = None,
+        masks: Tensor | None = None,
+        mode: str = "tensor",
+    ) -> Tensor:
+        """Step for model training.
+
+        Args:
+            inputs (Tensor): input labeled images
+            unlabeled_weak_images (Tensor, optional): unlabeled images with weak augmentations. Defaults to None.
+            unlabeled_strong_images (Tensor, optional): unlabeled images with strong augmentations. Defaults to None.
+            global_step (int, optional): global step. Defaults to None.
+            steps_per_epoch (int, optional): steps per epoch. Defaults to None.
+            img_metas (list[ImageInfo], optional): image meta information. Defaults to None.
+            unlabeled_img_metas (list[ImageInfo], optional): unlabeled image meta information. Defaults to None.
+            masks (Tensor, optional): ground truth masks for training. Defaults to None.
+            mode (str, optional): mode of forward. Defaults to "tensor".
+        """
+        if mode != "loss":
+            # only labeled images for validation and testing
+            return self.teacher_model(inputs, img_metas, masks, mode=mode)
+
+        if global_step is None or steps_per_epoch is None:
+            msg = "global_step and steps_per_epoch should be provided"
+            raise ValueError(msg)
+
+        if global_step > self.semisl_start_epoch * steps_per_epoch:
+            # generate pseudo labels, filter high entropy pixels, compute loss reweight;
+            # the drop percentage decays linearly to zero over filter_pixels_epochs epochs
+            percent_unreliable = self.drop_unrel_pixels_percent * (
+                1 - global_step / (self.filter_pixels_epochs * steps_per_epoch)
+            )
+            pl_from_teacher, reweight_unsup = self._generate_pseudo_labels(
+                unlabeled_weak_images,
+                percent_unreliable=percent_unreliable,
+            )
+            unlabeled_strong_images_aug, pl_from_teacher_aug = cut_mixer(unlabeled_strong_images, pl_from_teacher)
+            # extract features from labeled and unlabeled augmented images
+            student_labeled_logits = self.student_model(inputs, mode="tensor")
+            student_unlabeled_logits = self.student_model(unlabeled_strong_images_aug, mode="tensor")
+            # loss computation
+            loss_decode = self.student_model.calculate_loss(
+                student_labeled_logits,
+                img_metas,
+                masks=masks,
+                interpolate=True,
+            )
+            loss_decode_u = self.student_model.calculate_loss(
+                student_unlabeled_logits,
+                unlabeled_img_metas,
+                masks=pl_from_teacher_aug,
+                interpolate=True,
+            )
+            loss_decode_u = {f"{k}_unlabeled": v * reweight_unsup * self.unsup_weight for k, v in loss_decode_u.items()}
+            loss_decode.update(loss_decode_u)
+            return loss_decode
+
+        return self.student_model(inputs, img_metas, masks, mode="loss")
+
+    def _generate_pseudo_labels(self, ul_w_img: Tensor, percent_unreliable: float) -> tuple[Tensor, Tensor]:
+        """Generate pseudo labels from the teacher model and filter out unreliable pixels.
+
+        Args:
+            ul_w_img (Tensor): weakly augmented unlabeled images
+            percent_unreliable (float): percent of pseudo-label pixels to treat as unreliable
+
+        Returns:
+            tuple[Tensor, Tensor]: the pseudo labels and the unsupervised loss reweighting factor
+        """
+        with torch.no_grad():
+            teacher_out = self.teacher_model(ul_w_img, mode="tensor")
+            teacher_out = torch.nn.functional.interpolate(
+                teacher_out,
+                size=ul_w_img.shape[2:],
+                mode="bilinear",
+                align_corners=True,
+            )
+            teacher_prob_unsup = torch.softmax(teacher_out, dim=1)
+            _, pl_from_teacher = torch.max(teacher_prob_unsup, dim=1, keepdim=True)
+
+        # drop pixels with high entropy
+        reweight_unsup = 1.0
+        if percent_unreliable > 0:
+            keep_percent = 100 - percent_unreliable
+            batch_size, _, h, w = teacher_out.shape
+
+            entropy = -torch.sum(teacher_prob_unsup * torch.log(teacher_prob_unsup + 1e-10), dim=1, keepdim=True)
+
+            thresh = np.percentile(entropy[pl_from_teacher != 255].detach().cpu().numpy().flatten(), keep_percent)
+            thresh_mask = entropy.ge(thresh).bool() * (pl_from_teacher != 255).bool()
+
+            # mark as ignore index
+            pl_from_teacher[thresh_mask] = 255
+            # reweight unsupervised loss
+            reweight_unsup = batch_size * h * w / torch.sum(pl_from_teacher != 255)
+
+        return pl_from_teacher, reweight_unsup
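A standalone sketch (toy tensors, not OTX code) of the entropy-percentile filtering performed above: the teacher's per-pixel prediction entropy is thresholded at the keep-percentile, and pixels above the threshold are set to the ignore index 255 so the criterion skips them:

```python
# Toy reproduction of the pseudo-label filtering in _generate_pseudo_labels.
import numpy as np
import torch

logits = torch.randn(2, 4, 8, 8)  # teacher output: batch of 2, 4 classes
prob = torch.softmax(logits, dim=1)
_, pseudo_labels = torch.max(prob, dim=1, keepdim=True)

percent_unreliable = 20.0  # in the PR this decays over filter_pixels_epochs
keep_percent = 100 - percent_unreliable
entropy = -torch.sum(prob * torch.log(prob + 1e-10), dim=1, keepdim=True)

# threshold at the keep-percentile of entropy over non-ignored pixels
thresh = np.percentile(entropy[pseudo_labels != 255].numpy().flatten(), keep_percent)
pseudo_labels[entropy.ge(thresh) & (pseudo_labels != 255)] = 255  # ignore index
```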
diff --git a/src/otx/algo/segmentation/segnext.py b/src/otx/algo/segmentation/segnext.py
index 0c2eaff739b..f7421b6bee0 100644
--- a/src/otx/algo/segmentation/segnext.py
+++ b/src/otx/algo/segmentation/segnext.py
@@ -7,15 +7,16 @@
 
 from typing import Any, ClassVar
 
+import torch
 from torch import nn
 
 from otx.algo.segmentation.backbones import MSCAN
 from otx.algo.segmentation.heads import LightHamHead
+from otx.algo.segmentation.segmentors import BaseSegmModel, MeanTeacher
 from otx.algo.utils.support_otx_v1 import OTXv1Helper
+from otx.core.data.entity.segmentation import SegBatchDataEntity
 from otx.core.model.segmentation import TorchVisionCompatibleModel
 
-from .base_model import BaseSegmModel
-
 
 class SegNextB(BaseSegmModel):
     """SegNextB Model."""
@@ -143,3 +144,49 @@ def _optimization_config(self) -> dict[str, Any]:
             ],
         },
     }
+
+
+class SemiSLSegNext(OTXSegNext):
+    """SegNext model for Semi-SL training."""
+
+    def _customize_inputs(self, entity: SegBatchDataEntity) -> dict[str, Any]:
+        if not isinstance(entity, dict):
+            if self.training:
+                msg = "Unlabeled inputs must be provided for Semi-SL training."
+                raise RuntimeError(msg)
+            return super()._customize_inputs(entity)
+
+        entity["labeled"].masks = torch.stack(entity["labeled"].masks).long()
+        w_u_images = entity["weak_transforms"].images
+        s_u_images = entity["strong_transforms"].images
+        unlabeled_img_metas = entity["weak_transforms"].imgs_info
+        labeled_inputs = entity["labeled"]
+
+        return {
+            "inputs": labeled_inputs.images,
+            "unlabeled_weak_images": w_u_images,
+            "unlabeled_strong_images": s_u_images,
+            "global_step": self.trainer.global_step,
+            "steps_per_epoch": self.trainer.num_training_batches,
+            "img_metas": labeled_inputs.imgs_info,
+            "unlabeled_img_metas": unlabeled_img_metas,
+            "masks": labeled_inputs.masks,
+            "mode": "loss",
+        }
+
+    def _create_model(self) -> nn.Module:
+        segnext_model_class = SEGNEXT_VARIANTS[self.name_base_model]
+        # merge the user configuration over the defaults (user-provided keys win)
+        backbone_configuration = segnext_model_class.default_backbone_configuration | self.backbone_configuration
+        decode_head_configuration = (
+            segnext_model_class.default_decode_head_configuration | self.decode_head_configuration
+        )
+        # initialize backbones
+        backbone = MSCAN(**backbone_configuration)
+        decode_head = LightHamHead(num_classes=self.num_classes, **decode_head_configuration)
+        base_model = segnext_model_class(
+            backbone=backbone,
+            decode_head=decode_head,
+            criterion_configuration=self.criterion_configuration,
+        )
+        return MeanTeacher(base_model, unsup_weight=0.7, drop_unrel_pixels_percent=20, semisl_start_epoch=2)
diff --git a/src/otx/core/data/dataset/segmentation.py b/src/otx/core/data/dataset/segmentation.py
index 3651e961d23..363a15e84cc 100644
--- a/src/otx/core/data/dataset/segmentation.py
+++ b/src/otx/core/data/dataset/segmentation.py
@@ -10,6 +10,7 @@
 
 import cv2
 import numpy as np
+import torch
 from datumaro.components.annotation import Ellipse, Image, Mask, Polygon
 from torchvision import tv_tensors
 
@@ -202,7 +203,12 @@ def _get_item_impl(self, index: int) -> SegDataEntity | None:
         img = item.media_as(Image)
         ignored_labels: list[int] = []
         img_data, img_shape = self._get_img_data_and_shape(img)
-        mask = _extract_class_mask(item=item, img_shape=img_shape, ignore_index=self.ignore_index)
+        if item.annotations:
+            extracted_mask = _extract_class_mask(item=item, img_shape=img_shape, ignore_index=self.ignore_index)
+            masks = tv_tensors.Mask(extracted_mask[None])
+        else:
+            # unlabeled dataset for semi-supervised learning: use a dummy mask
+            masks = torch.tensor([[0]])
 
         entity = SegDataEntity(
             image=img_data,
@@ -213,7 +219,7 @@ def _get_item_impl(self, index: int) -> SegDataEntity | None:
                 image_color_channel=self.image_color_channel,
                 ignored_labels=ignored_labels,
             ),
-            masks=tv_tensors.Mask(mask[None]),
+            masks=masks,
         )
         transformed_entity = self._apply_transforms(entity)
         return transformed_entity.wrap(masks=transformed_entity.masks[0]) if transformed_entity else None
diff --git a/src/otx/core/model/segmentation.py b/src/otx/core/model/segmentation.py
index 282ee02864f..d4affefed6e 100644
--- a/src/otx/core/model/segmentation.py
+++ b/src/otx/core/model/segmentation.py
@@ -20,10 +20,13 @@
 from otx.core.metrics.dice import SegmCallable
 from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel
 from otx.core.schedulers import LRSchedulerListCallable
-from otx.core.types.export import TaskLevelExportParameters
+from otx.core.types.export import OTXExportFormatType, TaskLevelExportParameters
 from otx.core.types.label import LabelInfo, LabelInfoTypes, SegLabelInfo
+from otx.core.types.precision import OTXPrecisionType
 
 if TYPE_CHECKING:
+    from pathlib import Path
+
     from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
     from model_api.models.utils import ImageResultWithSoftPrediction
     from torch import Tensor
@@ -122,6 +125,32 @@ def get_dummy_input(self, batch_size: int = 1) -> SegBatchDataEntity:
         )
         return SegBatchDataEntity(batch_size, images, infos, masks=[])
 
+    def export(
+        self,
+        output_dir: Path,
+        base_name: str,
+        export_format: OTXExportFormatType,
+        precision: OTXPrecisionType = OTXPrecisionType.FP32,
+        to_exportable_code: bool = False,
+    ) -> Path:
+        """Export this model to the specified output directory.
+
+        Args:
+            output_dir (Path): directory for saving the exported model
+            base_name (str): base name for the exported model file. Extension is defined by the target export format
+            export_format (OTXExportFormatType): format of the output model
+            precision (OTXPrecisionType): precision of the output model
+            to_exportable_code (bool): flag to export model in exportable code with demo package
+
+        Returns:
+            Path: path to the exported model.
+ """ + if hasattr(self.model, "student_model"): + # use only teacher model + # TODO(Kirill): make this based on the training type + self.model = self.model.teacher_model + return super().export(output_dir, base_name, export_format, precision, to_exportable_code) + class TorchVisionCompatibleModel(OTXSegmentationModel): """Segmentation model compatible with torchvision data pipeline.""" diff --git a/src/otx/core/types/label.py b/src/otx/core/types/label.py index 6b4ff83218f..cd472965336 100644 --- a/src/otx/core/types/label.py +++ b/src/otx/core/types/label.py @@ -282,14 +282,6 @@ class SegLabelInfo(LabelInfo): ignore_index: int = 255 - def __post_init__(self): - if len(self.label_names) <= 1: - msg = ( - "The number of labels must be larger than 1. " - "Please, check dataset labels and add background label in case of binary segmentation." - ) - raise ValueError(msg) - @classmethod def from_num_classes(cls, num_classes: int) -> LabelInfo: """Create this object from the number of classes. diff --git a/src/otx/engine/utils/api.py b/src/otx/engine/utils/api.py index 0cf036a6cc9..48c89627d02 100644 --- a/src/otx/engine/utils/api.py +++ b/src/otx/engine/utils/api.py @@ -47,7 +47,9 @@ def list_models(task: OTXTaskType | None = None, pattern: str | None = None, pri >>> models = list_models(task="MULTI_CLASS_CLS", pattern="*efficient", print_table=True) """ task_type = OTXTaskType(task).name.lower() if task is not None else "**" - recipe_list = [str(recipe) for recipe in RECIPE_PATH.glob(f"**/{task_type}/*.yaml") if "_base_" not in recipe.parts] + recipe_list = [ + str(recipe) for recipe in RECIPE_PATH.glob(f"**/{task_type}/**/*.yaml") if "_base_" not in recipe.parts + ] if pattern is not None: # Always match keys with any postfix. diff --git a/src/otx/recipe/_base_/data/semisl/semantic_segmentation_semisl.yaml b/src/otx/recipe/_base_/data/semisl/semantic_segmentation_semisl.yaml new file mode 100644 index 00000000000..ec3d647ab39 --- /dev/null +++ b/src/otx/recipe/_base_/data/semisl/semantic_segmentation_semisl.yaml @@ -0,0 +1,147 @@ +task: SEMANTIC_SEGMENTATION +input_size: + - 512 + - 512 +mem_cache_size: 1GB +mem_cache_img_max_size: null +image_color_channel: RGB +data_format: common_semantic_segmentation_with_subset_dirs +include_polygons: true +unannotated_items_ratio: 0.0 +ignore_index: 255 +train_subset: + subset_name: train + batch_size: 8 + num_workers: 4 + transform_lib_type: TORCHVISION + to_tv_image: false + transforms: + - class_path: otx.core.data.transform_libs.torchvision.RandomResizedCrop + init_args: + scale: $(input_size) + crop_ratio_range: + - 0.2 + - 1.0 + aspect_ratio_range: + - 0.5 + - 2.0 + transform_mask: true + - class_path: otx.core.data.transform_libs.torchvision.RandomFlip + init_args: + prob: 0.5 + is_numpy_to_tvtensor: true + - class_path: torchvision.transforms.v2.RandomPhotometricDistort + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.28, 103.53] + std: [58.395, 57.12, 57.375] + sampler: + class_path: torch.utils.data.RandomSampler + +val_subset: + subset_name: val + batch_size: 8 + num_workers: 4 + transform_lib_type: TORCHVISION + to_tv_image: false + transforms: + - class_path: otx.core.data.transform_libs.torchvision.Resize + init_args: + scale: $(input_size) + transform_mask: true + is_numpy_to_tvtensor: true + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + - 
diff --git a/src/otx/recipe/_base_/data/semisl/semantic_segmentation_semisl.yaml b/src/otx/recipe/_base_/data/semisl/semantic_segmentation_semisl.yaml
new file mode 100644
index 00000000000..ec3d647ab39
--- /dev/null
+++ b/src/otx/recipe/_base_/data/semisl/semantic_segmentation_semisl.yaml
@@ -0,0 +1,147 @@
+task: SEMANTIC_SEGMENTATION
+input_size:
+  - 512
+  - 512
+mem_cache_size: 1GB
+mem_cache_img_max_size: null
+image_color_channel: RGB
+data_format: common_semantic_segmentation_with_subset_dirs
+include_polygons: true
+unannotated_items_ratio: 0.0
+ignore_index: 255
+train_subset:
+  subset_name: train
+  batch_size: 8
+  num_workers: 4
+  transform_lib_type: TORCHVISION
+  to_tv_image: false
+  transforms:
+    - class_path: otx.core.data.transform_libs.torchvision.RandomResizedCrop
+      init_args:
+        scale: $(input_size)
+        crop_ratio_range:
+          - 0.2
+          - 1.0
+        aspect_ratio_range:
+          - 0.5
+          - 2.0
+        transform_mask: true
+    - class_path: otx.core.data.transform_libs.torchvision.RandomFlip
+      init_args:
+        prob: 0.5
+        is_numpy_to_tvtensor: true
+    - class_path: torchvision.transforms.v2.RandomPhotometricDistort
+    - class_path: torchvision.transforms.v2.ToDtype
+      init_args:
+        dtype: ${as_torch_dtype:torch.float32}
+    - class_path: torchvision.transforms.v2.Normalize
+      init_args:
+        mean: [123.675, 116.28, 103.53]
+        std: [58.395, 57.12, 57.375]
+  sampler:
+    class_path: torch.utils.data.RandomSampler
+
+val_subset:
+  subset_name: val
+  batch_size: 8
+  num_workers: 4
+  transform_lib_type: TORCHVISION
+  to_tv_image: false
+  transforms:
+    - class_path: otx.core.data.transform_libs.torchvision.Resize
+      init_args:
+        scale: $(input_size)
+        transform_mask: true
+        is_numpy_to_tvtensor: true
+    - class_path: torchvision.transforms.v2.ToDtype
+      init_args:
+        dtype: ${as_torch_dtype:torch.float32}
+    - class_path: torchvision.transforms.v2.Normalize
+      init_args:
+        mean: [123.675, 116.28, 103.53]
+        std: [58.395, 57.12, 57.375]
+  sampler:
+    class_path: torch.utils.data.RandomSampler
+
+test_subset:
+  subset_name: test
+  num_workers: 4
+  batch_size: 8
+  transform_lib_type: TORCHVISION
+  to_tv_image: false
+  transforms:
+    - class_path: otx.core.data.transform_libs.torchvision.Resize
+      init_args:
+        scale: $(input_size)
+        transform_mask: true
+        is_numpy_to_tvtensor: true
+    - class_path: torchvision.transforms.v2.ToDtype
+      init_args:
+        dtype: ${as_torch_dtype:torch.float32}
+    - class_path: torchvision.transforms.v2.Normalize
+      init_args:
+        mean: [123.675, 116.28, 103.53]
+        std: [58.395, 57.12, 57.375]
+  sampler:
+    class_path: torch.utils.data.RandomSampler
+
+unlabeled_subset:
+  data_format: image_dir
+  batch_size: 8
+  subset_name: unlabeled
+  to_tv_image: false
+  transforms:
+    weak_transforms:
+      - class_path: otx.core.data.transform_libs.torchvision.Resize
+        init_args:
+          scale: $(input_size)
+          transform_mask: false
+          is_numpy_to_tvtensor: true
+      - class_path: torchvision.transforms.v2.ToDtype
+        init_args:
+          dtype: ${as_torch_dtype:torch.float32}
+          scale: false
+      - class_path: torchvision.transforms.v2.Normalize
+        init_args:
+          mean:
+            - 123.675
+            - 116.28
+            - 103.53
+          std:
+            - 58.395
+            - 57.12
+            - 57.375
+
+    strong_transforms:
+      - class_path: otx.core.data.transform_libs.torchvision.Resize
+        init_args:
+          scale: $(input_size)
+          transform_mask: false
+          is_numpy_to_tvtensor: true
+      - class_path: torchvision.transforms.v2.RandomPhotometricDistort
+        init_args:
+          p: 1.0
+      - class_path: torchvision.transforms.v2.RandomPosterize
+        init_args:
+          bits: 5
+      - class_path: torchvision.transforms.v2.RandomEqualize
+      - class_path: torchvision.transforms.v2.ToDtype
+        init_args:
+          dtype: ${as_torch_dtype:torch.float32}
+          scale: false
+      - class_path: torchvision.transforms.v2.Normalize
+        init_args:
+          mean:
+            - 123.675
+            - 116.28
+            - 103.53
+          std:
+            - 58.395
+            - 57.12
+            - 57.375
+
+  transform_lib_type: TORCHVISION
+  num_workers: 4
+  sampler:
+    class_path: torch.utils.data.RandomSampler
+    init_args: {}
diff --git a/src/otx/recipe/_base_/data/torchvision_semisl.yaml b/src/otx/recipe/_base_/data/semisl/torchvision_semisl.yaml
similarity index 100%
rename from src/otx/recipe/_base_/data/torchvision_semisl.yaml
rename to src/otx/recipe/_base_/data/semisl/torchvision_semisl.yaml
diff --git a/src/otx/recipe/classification/multi_class_cls/deit_tiny_semisl.yaml b/src/otx/recipe/classification/multi_class_cls/semisl/deit_tiny_semisl.yaml
similarity index 93%
rename from src/otx/recipe/classification/multi_class_cls/deit_tiny_semisl.yaml
rename to src/otx/recipe/classification/multi_class_cls/semisl/deit_tiny_semisl.yaml
index 2baf20fd5d5..34155bf6089 100644
--- a/src/otx/recipe/classification/multi_class_cls/deit_tiny_semisl.yaml
+++ b/src/otx/recipe/classification/multi_class_cls/semisl/deit_tiny_semisl.yaml
@@ -25,7 +25,7 @@ engine:
 
 callback_monitor: val/accuracy
 
-data: ../../_base_/data/torchvision_semisl.yaml
+data: ../../../_base_/data/semisl/torchvision_semisl.yaml
 overrides:
   max_epochs: 200
   callbacks:
diff --git a/src/otx/recipe/classification/multi_class_cls/dino_v2_semisl.yaml b/src/otx/recipe/classification/multi_class_cls/semisl/dino_v2_semisl.yaml
similarity index 93%
rename from src/otx/recipe/classification/multi_class_cls/dino_v2_semisl.yaml
rename to src/otx/recipe/classification/multi_class_cls/semisl/dino_v2_semisl.yaml
index 1a5f2c59e30..9b06bfb51e1 100644
--- a/src/otx/recipe/classification/multi_class_cls/dino_v2_semisl.yaml
+++ b/src/otx/recipe/classification/multi_class_cls/semisl/dino_v2_semisl.yaml
@@ -25,7 +25,7 @@ engine:
 
 callback_monitor: val/accuracy
 
-data: ../../_base_/data/torchvision_semisl.yaml
+data: ../../../_base_/data/semisl/torchvision_semisl.yaml
 overrides:
   max_epochs: 200
   callbacks:
diff --git a/src/otx/recipe/classification/multi_class_cls/efficientnet_b0_semisl.yaml b/src/otx/recipe/classification/multi_class_cls/semisl/efficientnet_b0_semisl.yaml
similarity index 94%
rename from src/otx/recipe/classification/multi_class_cls/efficientnet_b0_semisl.yaml
rename to src/otx/recipe/classification/multi_class_cls/semisl/efficientnet_b0_semisl.yaml
index bb412f0ae7f..30126060f6e 100644
--- a/src/otx/recipe/classification/multi_class_cls/efficientnet_b0_semisl.yaml
+++ b/src/otx/recipe/classification/multi_class_cls/semisl/efficientnet_b0_semisl.yaml
@@ -26,7 +26,7 @@ engine:
 
 callback_monitor: val/accuracy
 
-data: ../../_base_/data/torchvision_semisl.yaml
+data: ../../../_base_/data/semisl/torchvision_semisl.yaml
 overrides:
   max_epochs: 200
   callbacks:
diff --git a/src/otx/recipe/classification/multi_class_cls/efficientnet_v2_semisl.yaml b/src/otx/recipe/classification/multi_class_cls/semisl/efficientnet_v2_semisl.yaml
similarity index 94%
rename from src/otx/recipe/classification/multi_class_cls/efficientnet_v2_semisl.yaml
rename to src/otx/recipe/classification/multi_class_cls/semisl/efficientnet_v2_semisl.yaml
index 0dbf00e2222..b1f87665bde 100644
--- a/src/otx/recipe/classification/multi_class_cls/efficientnet_v2_semisl.yaml
+++ b/src/otx/recipe/classification/multi_class_cls/semisl/efficientnet_v2_semisl.yaml
@@ -26,7 +26,7 @@ engine:
 
 callback_monitor: val/accuracy
 
-data: ../../_base_/data/torchvision_semisl.yaml
+data: ../../../_base_/data/semisl/torchvision_semisl.yaml
 overrides:
   max_epochs: 200
   callbacks:
diff --git a/src/otx/recipe/classification/multi_class_cls/mobilenet_v3_large_semisl.yaml b/src/otx/recipe/classification/multi_class_cls/semisl/mobilenet_v3_large_semisl.yaml
similarity index 94%
rename from src/otx/recipe/classification/multi_class_cls/mobilenet_v3_large_semisl.yaml
rename to src/otx/recipe/classification/multi_class_cls/semisl/mobilenet_v3_large_semisl.yaml
index 6fc42696d0f..e0dab91d687 100644
--- a/src/otx/recipe/classification/multi_class_cls/mobilenet_v3_large_semisl.yaml
+++ b/src/otx/recipe/classification/multi_class_cls/semisl/mobilenet_v3_large_semisl.yaml
@@ -30,7 +30,7 @@ engine:
 
 callback_monitor: val/accuracy
 
-data: ../../_base_/data/torchvision_semisl.yaml
+data: ../../../_base_/data/semisl/torchvision_semisl.yaml
 overrides:
   max_epochs: 200
   callbacks:
diff --git a/src/otx/recipe/classification/multi_class_cls/tv_efficientnet_b3_semisl.yaml b/src/otx/recipe/classification/multi_class_cls/semisl/tv_efficientnet_b3_semisl.yaml
similarity index 93%
rename from src/otx/recipe/classification/multi_class_cls/tv_efficientnet_b3_semisl.yaml
rename to src/otx/recipe/classification/multi_class_cls/semisl/tv_efficientnet_b3_semisl.yaml
index 0ab09b9c658..a2b3c1d73d5 100644
--- a/src/otx/recipe/classification/multi_class_cls/tv_efficientnet_b3_semisl.yaml
+++ b/src/otx/recipe/classification/multi_class_cls/semisl/tv_efficientnet_b3_semisl.yaml
@@ -25,7 +25,7 @@ engine:
 
 callback_monitor: val/accuracy
 
-data: ../../_base_/data/torchvision_semisl.yaml
+data: ../../../_base_/data/semisl/torchvision_semisl.yaml
 overrides:
   max_epochs: 200
diff --git a/src/otx/recipe/classification/multi_class_cls/tv_efficientnet_v2_l_semisl.yaml b/src/otx/recipe/classification/multi_class_cls/semisl/tv_efficientnet_v2_l_semisl.yaml
similarity index 93%
rename from src/otx/recipe/classification/multi_class_cls/tv_efficientnet_v2_l_semisl.yaml
rename to src/otx/recipe/classification/multi_class_cls/semisl/tv_efficientnet_v2_l_semisl.yaml
index 28763900b04..ffe7e25a99b 100644
--- a/src/otx/recipe/classification/multi_class_cls/tv_efficientnet_v2_l_semisl.yaml
+++ b/src/otx/recipe/classification/multi_class_cls/semisl/tv_efficientnet_v2_l_semisl.yaml
@@ -25,7 +25,7 @@ engine:
 
 callback_monitor: val/accuracy
 
-data: ../../_base_/data/torchvision_semisl.yaml
+data: ../../../_base_/data/semisl/torchvision_semisl.yaml
 overrides:
   max_epochs: 200
diff --git a/src/otx/recipe/classification/multi_class_cls/tv_mobilenet_v3_small_semisl.yaml b/src/otx/recipe/classification/multi_class_cls/semisl/tv_mobilenet_v3_small_semisl.yaml
similarity index 93%
rename from src/otx/recipe/classification/multi_class_cls/tv_mobilenet_v3_small_semisl.yaml
rename to src/otx/recipe/classification/multi_class_cls/semisl/tv_mobilenet_v3_small_semisl.yaml
index e98ee3b1af2..824881433e4 100644
--- a/src/otx/recipe/classification/multi_class_cls/tv_mobilenet_v3_small_semisl.yaml
+++ b/src/otx/recipe/classification/multi_class_cls/semisl/tv_mobilenet_v3_small_semisl.yaml
@@ -25,7 +25,7 @@ engine:
 
 callback_monitor: val/accuracy
 
-data: ../../_base_/data/torchvision_semisl.yaml
+data: ../../../_base_/data/semisl/torchvision_semisl.yaml
 overrides:
   max_epochs: 200
diff --git a/src/otx/recipe/semantic_segmentation/semisl/dino_v2_semisl.yaml b/src/otx/recipe/semantic_segmentation/semisl/dino_v2_semisl.yaml
new file mode 100644
index 00000000000..8902a549cde
--- /dev/null
+++ b/src/otx/recipe/semantic_segmentation/semisl/dino_v2_semisl.yaml
@@ -0,0 +1,83 @@
+model:
+  class_path: otx.algo.segmentation.dino_v2_seg.DinoV2SegSemiSL
+  init_args:
+    label_info: 2
+
+    criterion_configuration:
+      - type: CrossEntropyLoss
+        params:
+          ignore_index: 255
+
+    optimizer:
+      class_path: torch.optim.AdamW
+      init_args:
+        lr: 0.001
+        betas:
+          - 0.9
+          - 0.999
+        weight_decay: 0.0001
+
+    export_image_configuration:
+      image_size:
+        - 1
+        - 3
+        - 560
+        - 560
+
+    scheduler:
+      class_path: torch.optim.lr_scheduler.PolynomialLR
+      init_args:
+        total_iters: 100
+        power: 0.9
+        last_epoch: -1
+
+engine:
+  task: SEMANTIC_SEGMENTATION
+  device: auto
+
+callback_monitor: val/Dice
+
+data: ../../_base_/data/semisl/semantic_segmentation_semisl.yaml
+overrides:
+  data:
+    input_size:
+      - 560
+      - 560
+    train_subset:
+      transforms:
+        - class_path: otx.core.data.transform_libs.torchvision.RandomResizedCrop
+          init_args:
+            scale: $(input_size)
+
+    val_subset:
+      transforms:
+        - class_path: otx.core.data.transform_libs.torchvision.Resize
+          init_args:
+            scale: $(input_size)
+
+    test_subset:
+      transforms:
+        - class_path: otx.core.data.transform_libs.torchvision.Resize
+          init_args:
+            scale: $(input_size)
+
+    unlabeled_subset:
+      transforms:
+        weak_transforms:
+          - class_path: otx.core.data.transform_libs.torchvision.Resize
+            init_args:
+              scale: $(input_size)
+
+        strong_transforms:
+          - class_path: otx.core.data.transform_libs.torchvision.Resize
+            init_args:
+              scale: $(input_size)
+
+  callbacks:
+    - class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
+      init_args:
+        warmup_iters: 100
+    - class_path: otx.algo.callbacks.ema_mean_teacher.EMAMeanTeacher
+      init_args:
+        momentum: 0.99
+        start_epoch: 2
diff --git a/src/otx/recipe/semantic_segmentation/semisl/litehrnet_18_semisl.yaml b/src/otx/recipe/semantic_segmentation/semisl/litehrnet_18_semisl.yaml
new file mode 100644
index 00000000000..8a6c02a5fd4
--- /dev/null
+++ b/src/otx/recipe/semantic_segmentation/semisl/litehrnet_18_semisl.yaml
@@ -0,0 +1,48 @@
+model:
+  class_path: otx.algo.segmentation.litehrnet.LiteHRNetSemiSL
+  init_args:
+    label_info: 2
+    name_base_model: LiteHRNet18
+
+    criterion_configuration:
+      - type: CrossEntropyLoss
+        params:
+          ignore_index: 255
+
+    optimizer:
+      class_path: torch.optim.Adam
+      init_args:
+        lr: 0.001
+        betas:
+          - 0.9
+          - 0.999
+        weight_decay: 0.0
+
+    scheduler:
+      class_path: otx.core.schedulers.LinearWarmupSchedulerCallable
+      init_args:
+        num_warmup_steps: 100
+        main_scheduler_callable:
+          class_path: lightning.pytorch.cli.ReduceLROnPlateau
+          init_args:
+            mode: max
+            factor: 0.1
+            patience: 4
+            monitor: val/Dice
+
+engine:
+  task: SEMANTIC_SEGMENTATION
+  device: auto
+
+callback_monitor: val/Dice
+
+data: ../../_base_/data/semisl/semantic_segmentation_semisl.yaml
+overrides:
+  callbacks:
+    - class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
+      init_args:
+        warmup_iters: 100
+    - class_path: otx.algo.callbacks.ema_mean_teacher.EMAMeanTeacher
+      init_args:
+        momentum: 0.99
+        start_epoch: 2
diff --git a/src/otx/recipe/semantic_segmentation/semisl/litehrnet_s_semisl.yaml b/src/otx/recipe/semantic_segmentation/semisl/litehrnet_s_semisl.yaml
new file mode 100644
index 00000000000..c34bf5436a4
--- /dev/null
+++ b/src/otx/recipe/semantic_segmentation/semisl/litehrnet_s_semisl.yaml
@@ -0,0 +1,48 @@
+model:
+  class_path: otx.algo.segmentation.litehrnet.LiteHRNetSemiSL
+  init_args:
+    label_info: 2
+    name_base_model: LiteHRNetS
+
+    criterion_configuration:
+      - type: CrossEntropyLoss
+        params:
+          ignore_index: 255
+
+    optimizer:
+      class_path: torch.optim.Adam
+      init_args:
+        lr: 0.001
+        betas:
+          - 0.9
+          - 0.999
+        weight_decay: 0.0
+
+    scheduler:
+      class_path: otx.core.schedulers.LinearWarmupSchedulerCallable
+      init_args:
+        num_warmup_steps: 100
+        main_scheduler_callable:
+          class_path: lightning.pytorch.cli.ReduceLROnPlateau
+          init_args:
+            mode: max
+            factor: 0.1
+            patience: 4
+            monitor: val/Dice
+
+engine:
+  task: SEMANTIC_SEGMENTATION
+  device: auto
+
+callback_monitor: val/Dice
+
+data: ../../_base_/data/semisl/semantic_segmentation_semisl.yaml
+overrides:
+  callbacks:
+    - class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
+      init_args:
+        warmup_iters: 100
+    - class_path: otx.algo.callbacks.ema_mean_teacher.EMAMeanTeacher
+      init_args:
+        momentum: 0.99
+        start_epoch: 2
diff --git a/src/otx/recipe/semantic_segmentation/semisl/litehrnet_x_semisl.yaml b/src/otx/recipe/semantic_segmentation/semisl/litehrnet_x_semisl.yaml
new file mode 100644
index 00000000000..a5f6e8f0606
--- /dev/null
+++ b/src/otx/recipe/semantic_segmentation/semisl/litehrnet_x_semisl.yaml
@@ -0,0 +1,48 @@
+model:
+  class_path: otx.algo.segmentation.litehrnet.LiteHRNetSemiSL
+  init_args:
+    label_info: 2
+    name_base_model: LiteHRNetX
+
+    criterion_configuration:
+      - type: CrossEntropyLoss
+        params:
+          ignore_index: 255
+
+    optimizer:
+      class_path: torch.optim.Adam
+      init_args:
+        lr: 0.001
+        betas:
+          - 0.9
+          - 0.999
+        weight_decay: 0.0
+
+    scheduler:
+      class_path: otx.core.schedulers.LinearWarmupSchedulerCallable
+      init_args:
+        num_warmup_steps: 100
+        main_scheduler_callable:
+          class_path: lightning.pytorch.cli.ReduceLROnPlateau
+          init_args:
+            mode: max
+            factor: 0.1
+            patience: 4
+            monitor: val/Dice
+
+engine:
+  task: SEMANTIC_SEGMENTATION
+  device: auto
+
+callback_monitor: val/Dice
+
+data: ../../_base_/data/semisl/semantic_segmentation_semisl.yaml
+overrides:
+  callbacks:
+    - class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
+      init_args:
+        warmup_iters: 100
+    - class_path: otx.algo.callbacks.ema_mean_teacher.EMAMeanTeacher
+      init_args:
+        momentum: 0.99
+        start_epoch: 2
diff --git a/src/otx/recipe/semantic_segmentation/semisl/segnext_b_semisl.yaml b/src/otx/recipe/semantic_segmentation/semisl/segnext_b_semisl.yaml
new file mode 100644
index 00000000000..d8557a58465
--- /dev/null
+++ b/src/otx/recipe/semantic_segmentation/semisl/segnext_b_semisl.yaml
@@ -0,0 +1,50 @@
+model:
+  class_path: otx.algo.segmentation.segnext.OTXSegNext
+  init_args:
+    label_info: 2
+    name_base_model: SegNextB
+
+    criterion_configuration:
+      - type: CrossEntropyLoss
+        params:
+          ignore_index: 255
+
+    optimizer:
+      class_path: torch.optim.AdamW
+      init_args:
+        lr: 0.00006
+        betas:
+          - 0.9
+          - 0.999
+        weight_decay: 0.01
+
+    scheduler:
+      class_path: otx.core.schedulers.LinearWarmupSchedulerCallable
+      init_args:
+        num_warmup_steps: 20
+        main_scheduler_callable:
+          class_path: torch.optim.lr_scheduler.PolynomialLR
+          init_args:
+            total_iters: 100
+            power: 0.9
+            last_epoch: -1
+
+engine:
+  task: SEMANTIC_SEGMENTATION
+  device: auto
+
+callback_monitor: val/Dice
+
+data: ../../_base_/data/semisl/semantic_segmentation_semisl.yaml
+
+overrides:
+  model:
+    class_path: otx.algo.segmentation.segnext.SemiSLSegNext
+  callbacks:
+    - class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
+      init_args:
+        warmup_iters: 100
+    - class_path: otx.algo.callbacks.ema_mean_teacher.EMAMeanTeacher
+      init_args:
+        momentum: 0.99
+        start_epoch: 2
diff --git a/src/otx/recipe/semantic_segmentation/semisl/segnext_s_semisl.yaml b/src/otx/recipe/semantic_segmentation/semisl/segnext_s_semisl.yaml
new file mode 100644
index 00000000000..231587d4835
--- /dev/null
+++ b/src/otx/recipe/semantic_segmentation/semisl/segnext_s_semisl.yaml
@@ -0,0 +1,47 @@
+model:
+  class_path: otx.algo.segmentation.segnext.SemiSLSegNext
+  init_args:
+    label_info: 2
+    name_base_model: SegNextS
+
+    criterion_configuration:
+      - type: CrossEntropyLoss
+        params:
+          ignore_index: 255
+
+    optimizer:
+      class_path: torch.optim.AdamW
+      init_args:
+        lr: 0.00006
+        betas:
+          - 0.9
+          - 0.999
+        weight_decay: 0.01
+
+    scheduler:
+      class_path: otx.core.schedulers.LinearWarmupSchedulerCallable
+      init_args:
+        num_warmup_steps: 20
+        main_scheduler_callable:
+          class_path: torch.optim.lr_scheduler.PolynomialLR
+          init_args:
+            total_iters: 100
+            power: 0.9
+            last_epoch: -1
+
+engine:
+  task: SEMANTIC_SEGMENTATION
+  device: auto
+
+callback_monitor: val/Dice
+
+data: ../../_base_/data/semisl/semantic_segmentation_semisl.yaml
+overrides:
+  callbacks:
+    - class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
+      init_args:
+        warmup_iters: 100
+    - class_path: otx.algo.callbacks.ema_mean_teacher.EMAMeanTeacher
+      init_args:
+        momentum: 0.99
+        start_epoch: 2
diff --git a/src/otx/recipe/semantic_segmentation/semisl/segnext_t_semisl.yaml b/src/otx/recipe/semantic_segmentation/semisl/segnext_t_semisl.yaml
new file mode 100644
index 00000000000..6152665bb0a
--- /dev/null
+++ b/src/otx/recipe/semantic_segmentation/semisl/segnext_t_semisl.yaml
@@ -0,0 +1,47 @@
+model:
+  class_path: otx.algo.segmentation.segnext.SemiSLSegNext
+  init_args:
+    label_info: 2
+    name_base_model: SegNextT
+
+    criterion_configuration:
+      - type: CrossEntropyLoss
+        params:
+          ignore_index: 255
+
+    optimizer:
+      class_path: torch.optim.AdamW
+      init_args:
+        lr: 0.00006
+        betas:
+          - 0.9
+          - 0.999
+        weight_decay: 0.01
+
+    scheduler:
+      class_path: otx.core.schedulers.LinearWarmupSchedulerCallable
+      init_args:
+        num_warmup_steps: 20
+        main_scheduler_callable:
+          class_path: torch.optim.lr_scheduler.PolynomialLR
+          init_args:
+            total_iters: 100
+            power: 0.9
+            last_epoch: -1
+
+engine:
+  task: SEMANTIC_SEGMENTATION
+  device: auto
+
+callback_monitor: val/Dice
+
+data: ../../_base_/data/semisl/semantic_segmentation_semisl.yaml
+overrides:
+  callbacks:
+    - class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
+      init_args:
+        warmup_iters: 100
+    - class_path: otx.algo.callbacks.ema_mean_teacher.EMAMeanTeacher
+      init_args:
+        momentum: 0.99
+        start_epoch: 2
diff --git a/tests/assets/common_semantic_segmentation_dataset/unlabeled/0001.png b/tests/assets/common_semantic_segmentation_dataset/unlabeled/0001.png
new file mode 100644
index 00000000000..058399a1726
Binary files /dev/null and b/tests/assets/common_semantic_segmentation_dataset/unlabeled/0001.png differ
diff --git a/tests/assets/common_semantic_segmentation_dataset/unlabeled/0002.png b/tests/assets/common_semantic_segmentation_dataset/unlabeled/0002.png
new file mode 100644
index 00000000000..058399a1726
Binary files /dev/null and b/tests/assets/common_semantic_segmentation_dataset/unlabeled/0002.png differ
diff --git a/tests/assets/common_semantic_segmentation_dataset/unlabeled/0003.png b/tests/assets/common_semantic_segmentation_dataset/unlabeled/0003.png
new file mode 100644
index 00000000000..058399a1726
Binary files /dev/null and b/tests/assets/common_semantic_segmentation_dataset/unlabeled/0003.png differ
diff --git a/tests/assets/common_semantic_segmentation_dataset/unlabeled/0004.png b/tests/assets/common_semantic_segmentation_dataset/unlabeled/0004.png
new file mode 100644
index 00000000000..02d1fe884f7
Binary files /dev/null and b/tests/assets/common_semantic_segmentation_dataset/unlabeled/0004.png differ
diff --git a/tests/assets/common_semantic_segmentation_dataset/unlabeled/0005.png b/tests/assets/common_semantic_segmentation_dataset/unlabeled/0005.png
new file mode 100644
index 00000000000..02d1fe884f7
Binary files /dev/null and b/tests/assets/common_semantic_segmentation_dataset/unlabeled/0005.png differ
diff --git a/tests/integration/api/test_xai.py b/tests/integration/api/test_xai.py
index e668bb077df..5471fc8d8dd 100644
--- a/tests/integration/api/test_xai.py
+++ b/tests/integration/api/test_xai.py
@@ -41,13 +41,14 @@ def test_forward_explain(
     Returns:
         None
     """
-    task = recipe.split("/")[-2]
-    model_name = recipe.split("/")[-1].split(".")[0]
+
+    recipe_split = recipe.split("/")
+    model_name = recipe_split[-1].split(".")[0]
+    is_semisl = model_name.endswith("_semisl")
+    task = recipe_split[-2] if not is_semisl else recipe_split[-3]
 
     if "dino" in model_name:
         pytest.skip("DINO is not supported.")
 
-    if "_semisl" in model_name:
-        pytest.skip("Semi-SL is not supported.")
-
     if "maskrcnn_r50_tv" in model_name:
         pytest.skip("MaskRCNN R50 Torchvision model doesn't support explain.")
@@ -101,8 +102,10 @@ def test_predict_with_explain(
     Returns:
         None
     """
-    task = recipe.split("/")[-2]
-    model_name = recipe.split("/")[-1].split(".")[0]
+    recipe_split = recipe.split("/")
+    model_name = recipe_split[-1].split(".")[0]
+    is_semisl = model_name.endswith("_semisl")
+    task = recipe_split[-2] if not is_semisl else recipe_split[-3]
 
     if "dino" in model_name:
         pytest.skip("DINO is not supported.")
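A worked example of the recipe-path parsing that the updated tests rely on: Semi-SL recipes sit one directory deeper than their supervised counterparts, so the task segment moves from `parts[-2]` to `parts[-3]`:

```python
# Illustrative path; mirrors the parsing used in test_xai.py and test_cli.py.
recipe = "src/otx/recipe/semantic_segmentation/semisl/litehrnet_18_semisl.yaml"
recipe_split = recipe.split("/")
model_name = recipe_split[-1].split(".")[0]  # "litehrnet_18_semisl"
is_semisl = model_name.endswith("_semisl")   # True
task = recipe_split[-2] if not is_semisl else recipe_split[-3]
assert task == "semantic_segmentation"
```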
diff --git a/tests/integration/cli/test_cli.py b/tests/integration/cli/test_cli.py
index 73908782e98..4f7fbe2a74f 100644
--- a/tests/integration/cli/test_cli.py
+++ b/tests/integration/cli/test_cli.py
@@ -27,8 +27,10 @@ def fxt_trained_model(
     tmp_path,
 ):
     recipe = request.param
-    task = recipe.split("/")[-2]
-    model_name = recipe.split("/")[-1].split(".")[0]
+    recipe_split = recipe.split("/")
+    model_name = recipe_split[-1].split(".")[0]
+    is_semisl = model_name.endswith("_semisl")
+    task = recipe_split[-2] if not is_semisl else recipe_split[-3]
 
     # 1) otx train
     tmp_path_train = tmp_path / f"otx_train_{model_name}"
@@ -48,11 +50,11 @@ def fxt_trained_model(
         *fxt_cli_override_command_per_task[task],
     ]
 
-    if model_name.endswith("_semisl") and "multi_class_cls" in recipe:
+    if is_semisl:
         command_cfg.extend(
             [
                 "--data.unlabeled_subset.data_root",
-                fxt_target_dataset_per_task["multi_class_cls_semisl"],
+                fxt_target_dataset_per_task[f"{task}_semisl"],
             ],
         )
 
@@ -434,9 +436,8 @@ def test_otx_ov_test(
     assert len(metric_result) > 0
 
 
-@pytest.mark.parametrize("recipe", pytest.RECIPE_LIST, ids=lambda x: "/".join(Path(x).parts[-2:]))
 def test_otx_hpo_e2e(
-    recipe: str,
+    fxt_trained_model,
     tmp_path: Path,
     fxt_accelerator: str,
     fxt_target_dataset_per_task: dict,
@@ -453,13 +454,14 @@ def test_otx_hpo_e2e(
     Returns:
         None
     """
-    task = recipe.split("/")[-2]
-    model_name = recipe.split("/")[-1].split(".")[0]
+    recipe, task, model_name, _ = fxt_trained_model
 
    if task.upper() == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
         pytest.skip("ZERO_SHOT_VISUAL_PROMPTING doesn't support HPO.")
     if "padim" in recipe:
         pytest.skip("padim model doesn't support HPO.")
+    if model_name.endswith("_semisl"):
+        pytest.skip("Semi-supervised learning model doesn't support HPO.")
 
     tmp_path_hpo = tmp_path / f"otx_hpo_{model_name}"
     tmp_path_hpo.mkdir(parents=True)
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index 954c4357a50..ead1117c6dd 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -124,6 +124,7 @@ def fxt_target_dataset_per_task() -> dict:
         "rotated_detection": "tests/assets/car_tree_bug",
         "instance_segmentation": "tests/assets/car_tree_bug",
         "semantic_segmentation": "tests/assets/common_semantic_segmentation_dataset/supervised",
+        "semantic_segmentation_semisl": "tests/assets/common_semantic_segmentation_dataset/unlabeled",
         "action_classification": "tests/assets/action_classification_dataset/",
         "visual_prompting": "tests/assets/car_tree_bug",
         "zero_shot_visual_prompting": "tests/assets/car_tree_bug_zero_shot",
Benchmark.Model(task="semantic_segmentation", name="segnext_t_semisl", category="other"), + Benchmark.Model(task="semantic_segmentation", name="dino_v2_semisl", category="other"), + ] + + DATASET_TEST_CASES = [ # noqa: RUF012 + Benchmark.Dataset( + name="kvasir", + path=Path("semantic_seg/semisl/kvasir_24"), + group="small", + num_repeat=5, + unlabeled_data_path=Path("semantic_seg/semisl/unlabeled_images/kvasir"), + extra_overrides={}, + ), + Benchmark.Dataset( + name="kitti", + path=Path("semantic_seg/semisl/kitti_18"), + group="small", + num_repeat=5, + unlabeled_data_path=Path("semantic_seg/semisl/unlabeled_images/kitti"), + extra_overrides={}, + ), + Benchmark.Dataset( + name="cityscapes", + path=Path("semantic_seg/semisl/cityscapes"), + group="medium", + num_repeat=5, + unlabeled_data_path=Path("semantic_seg/semisl/unlabeled_images/cityscapes"), + extra_overrides={}, + ), + Benchmark.Dataset( + name="pascal_voc", + path=Path("semantic_seg/semisl/pascal_voc"), + group="large", + num_repeat=5, + unlabeled_data_path=Path("semantic_seg/semisl/unlabeled_images/pascal_voc"), + extra_overrides={}, + ), + ] + + @pytest.mark.parametrize( + "fxt_model", + MODEL_TEST_CASES, + ids=lambda model: model.name, + indirect=True, + ) + @pytest.mark.parametrize( + "fxt_dataset", + DATASET_TEST_CASES, + ids=lambda dataset: dataset.name, + indirect=True, + ) + def test_perf( + self, + fxt_model: Benchmark.Model, + fxt_dataset: Benchmark.Dataset, + fxt_benchmark: Benchmark, + fxt_accelerator: str, + ): + if fxt_model.name == "dino_v2" and fxt_accelerator == "xpu": + pytest.skip(f"{fxt_model.name} doesn't support {fxt_accelerator}.") + + self._test_perf( + model=fxt_model, + dataset=fxt_dataset, + benchmark=fxt_benchmark, + criteria=self.BENCHMARK_CRITERIA, + ) diff --git a/tests/unit/algo/callbacks/test_ema_mean_teacher.py b/tests/unit/algo/callbacks/test_ema_mean_teacher.py new file mode 100644 index 00000000000..c68a899970a --- /dev/null +++ b/tests/unit/algo/callbacks/test_ema_mean_teacher.py @@ -0,0 +1,68 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +from unittest.mock import MagicMock, patch + +import pytest +from otx.algo.callbacks.ema_mean_teacher import EMAMeanTeacher + + +class TestEMAMeanTeacher: + @pytest.fixture() + def ema_mean_teacher(self): + return EMAMeanTeacher(momentum=0.99, start_epoch=1) + + def test_initialization(self, ema_mean_teacher): + assert ema_mean_teacher.momentum == 0.99 + assert ema_mean_teacher.start_epoch == 1 + assert not ema_mean_teacher.synced_models + + @patch("otx.algo.callbacks.ema_mean_teacher.Trainer") + @patch("otx.algo.callbacks.ema_mean_teacher.LightningModule") + def test_on_train_start(self, mock_trainer, mock_pl_module, ema_mean_teacher): + mock_model = MagicMock() + mock_model.student_model = MagicMock() + mock_model.teacher_model = MagicMock() + mock_trainer.model.model = mock_model + + ema_mean_teacher.on_train_start(mock_trainer, mock_pl_module) + + assert ema_mean_teacher.src_model is not None + assert ema_mean_teacher.dst_model is not None + assert ema_mean_teacher.src_model == mock_model.student_model + assert ema_mean_teacher.dst_model == mock_model.teacher_model + + @patch("otx.algo.callbacks.ema_mean_teacher.Trainer") + @patch("otx.algo.callbacks.ema_mean_teacher.LightningModule") + def test_on_train_batch_end(self, mock_trainer, mock_pl_module, ema_mean_teacher): + mock_trainer.current_epoch = 2 + mock_trainer.global_step = 10 + + ema_mean_teacher.src_params = {"param": MagicMock(requires_grad=True)} + 
diff --git a/tests/unit/algo/callbacks/test_ema_mean_teacher.py b/tests/unit/algo/callbacks/test_ema_mean_teacher.py
new file mode 100644
index 00000000000..c68a899970a
--- /dev/null
+++ b/tests/unit/algo/callbacks/test_ema_mean_teacher.py
@@ -0,0 +1,68 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+from unittest.mock import MagicMock, patch
+
+import pytest
+from otx.algo.callbacks.ema_mean_teacher import EMAMeanTeacher
+
+
+class TestEMAMeanTeacher:
+    @pytest.fixture()
+    def ema_mean_teacher(self):
+        return EMAMeanTeacher(momentum=0.99, start_epoch=1)
+
+    def test_initialization(self, ema_mean_teacher):
+        assert ema_mean_teacher.momentum == 0.99
+        assert ema_mean_teacher.start_epoch == 1
+        assert not ema_mean_teacher.synced_models
+
+    @patch("otx.algo.callbacks.ema_mean_teacher.LightningModule")
+    @patch("otx.algo.callbacks.ema_mean_teacher.Trainer")
+    def test_on_train_start(self, mock_trainer, mock_pl_module, ema_mean_teacher):
+        mock_model = MagicMock()
+        mock_model.student_model = MagicMock()
+        mock_model.teacher_model = MagicMock()
+        mock_trainer.model.model = mock_model
+
+        ema_mean_teacher.on_train_start(mock_trainer, mock_pl_module)
+
+        assert ema_mean_teacher.src_model is not None
+        assert ema_mean_teacher.dst_model is not None
+        assert ema_mean_teacher.src_model == mock_model.student_model
+        assert ema_mean_teacher.dst_model == mock_model.teacher_model
+
+    @patch("otx.algo.callbacks.ema_mean_teacher.LightningModule")
+    @patch("otx.algo.callbacks.ema_mean_teacher.Trainer")
+    def test_on_train_batch_end(self, mock_trainer, mock_pl_module, ema_mean_teacher):
+        mock_trainer.current_epoch = 2
+        mock_trainer.global_step = 10
+
+        ema_mean_teacher.src_params = {"param": MagicMock(requires_grad=True)}
+        ema_mean_teacher.dst_params = {"param": MagicMock(requires_grad=True)}
+
+        ema_mean_teacher.on_train_batch_end(mock_trainer, mock_pl_module, None, None, None)
+        assert ema_mean_teacher.synced_models is True
+        assert ema_mean_teacher.dst_params["param"].data.copy_.call_count == 2  # one initial sync, one EMA update
+
+    def test_copy_model(self, ema_mean_teacher):
+        src_param = MagicMock(requires_grad=True)
+        dst_param = MagicMock(requires_grad=True)
+        ema_mean_teacher.src_params = {"param": src_param}
+        ema_mean_teacher.dst_params = {"param": dst_param}
+
+        ema_mean_teacher._copy_model()
+
+        dst_param.data.copy_.assert_called_once_with(src_param.data)
+
+    def test_ema_model(self, ema_mean_teacher):
+        src_param = MagicMock(requires_grad=True)
+        dst_param = MagicMock(requires_grad=True)
+        ema_mean_teacher.src_params = {"param": src_param}
+        ema_mean_teacher.dst_params = {"param": dst_param}
+        ema_mean_teacher.synced_models = True
+        ema_mean_teacher._ema_model(global_step=10)
+
+        momentum = min(1 - 1 / (10 + 1), ema_mean_teacher.momentum)
+        expected_value = dst_param.data * momentum + src_param.data * (1 - momentum)
+        dst_param.data.copy_.assert_called_once_with(expected_value)
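As a worked example of the schedule asserted in test_ema_model above, the effective momentum min(1 - 1/(global_step + 1), momentum) starts at zero, so the teacher begins as an exact copy of the student, then ramps toward the configured cap:

    momentum_cap = 0.99
    for step in (0, 10, 100, 1000):
        m = min(1 - 1 / (step + 1), momentum_cap)
        print(f"step={step}: effective momentum={m:.4f}")
    # step=0:    0.0000  (teacher weights fully overwritten by the student)
    # step=10:   0.9091  (the value exercised by test_ema_model)
    # step=100:  0.9900  (the 0.99 cap is already active)
    # step=1000: 0.9900  (teacher trails the student slowly from here on)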
diff --git a/tests/unit/algo/segmentation/segmentors/test_mean_teacher.py b/tests/unit/algo/segmentation/segmentors/test_mean_teacher.py
new file mode 100644
index 00000000000..47ce0c32eb1
--- /dev/null
+++ b/tests/unit/algo/segmentation/segmentors/test_mean_teacher.py
@@ -0,0 +1,80 @@
+import pytest
+import torch
+from otx.algo.segmentation.segmentors import BaseSegmModel, MeanTeacher
+from otx.core.data.entity.base import ImageInfo
+from torch import nn
+
+
+class TestMeanTeacher:
+    @pytest.fixture()
+    def model(self):
+        decode_head = nn.Conv2d(3, 2, 1)
+        decode_head.num_classes = 2
+        model = BaseSegmModel(
+            backbone=nn.Sequential(nn.Conv2d(3, 5, 1), nn.ReLU(), nn.Conv2d(5, 3, 1)),
+            decode_head=decode_head,
+        )
+        return MeanTeacher(model)
+
+    @pytest.fixture()
+    def inputs(self):
+        return torch.randn(4, 3, 10, 10)
+
+    @pytest.fixture()
+    def unlabeled_weak_images(self):
+        return torch.randn(4, 3, 10, 10)
+
+    @pytest.fixture()
+    def unlabeled_strong_images(self):
+        return torch.randn(4, 3, 10, 10)
+
+    @pytest.fixture()
+    def img_metas(self):
+        return [ImageInfo(img_idx=i, img_shape=(10, 10), ori_shape=(10, 10)) for i in range(4)]
+
+    @pytest.fixture()
+    def masks(self):
+        return torch.randint(0, 2, size=(4, 10, 10)).long()
+
+    def test_forward_labeled_images(self, model, inputs, img_metas):
+        output = model.forward(inputs, img_metas, mode="tensor")
+        assert output.shape == (4, 2, 10, 10)
+
+    def test_forward_unlabeled_images(
+        self,
+        model,
+        inputs,
+        unlabeled_weak_images,
+        unlabeled_strong_images,
+        img_metas,
+        masks,
+    ):
+        output = model.forward(
+            inputs,
+            unlabeled_weak_images,
+            unlabeled_strong_images,
+            img_metas=img_metas,
+            unlabeled_img_metas=img_metas,
+            global_step=1,
+            steps_per_epoch=1,
+            masks=masks,
+            mode="loss",
+        )
+        assert isinstance(output, dict)
+        assert "loss_ce_ignore" in output
+        assert "loss_ce_ignore_unlabeled" in output
+        assert isinstance(output["loss_ce_ignore"], torch.Tensor)
+        assert isinstance(output["loss_ce_ignore_unlabeled"], torch.Tensor)
+        assert output["loss_ce_ignore"] > 0
+        assert output["loss_ce_ignore_unlabeled"] > 0
+
+    def test_generate_pseudo_labels(self, model, unlabeled_weak_images):
+        pl_from_teacher, reweight_unsup = model._generate_pseudo_labels(
+            unlabeled_weak_images,
+            percent_unreliable=20,
+        )
+
+        assert isinstance(pl_from_teacher, torch.Tensor)
+        assert pl_from_teacher.shape == (4, 1, 10, 10)
+        assert isinstance(reweight_unsup, torch.Tensor)
+        assert isinstance(reweight_unsup.item(), float)
diff --git a/tests/unit/engine/utils/test_auto_configurator.py b/tests/unit/engine/utils/test_auto_configurator.py
index 078c81fc84c..4ba0502c391 100644
--- a/tests/unit/engine/utils/test_auto_configurator.py
+++ b/tests/unit/engine/utils/test_auto_configurator.py
@@ -25,7 +25,7 @@ def fxt_data_root_per_task_type() -> dict:
         "tests/assets/classification_dataset": OTXTaskType.MULTI_CLASS_CLS,
         "tests/assets/multilabel_classification": OTXTaskType.MULTI_LABEL_CLS,
         "tests/assets/car_tree_bug": OTXTaskType.DETECTION,
-        "tests/assets/common_semantic_segmentation_dataset": OTXTaskType.SEMANTIC_SEGMENTATION,
+        "tests/assets/common_semantic_segmentation_dataset/supervised": OTXTaskType.SEMANTIC_SEGMENTATION,
     }
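The (4, 1, 10, 10) shape asserted in test_generate_pseudo_labels above follows from reducing the teacher's class dimension with keepdim. A minimal sketch of that shape arithmetic; the actual MeanTeacher additionally down-weights the least-confident percent_unreliable pixels, which is not reproduced here:

    import torch

    teacher_logits = torch.randn(4, 2, 10, 10)  # (batch, num_classes, H, W), as in the fixtures above
    probs = torch.softmax(teacher_logits, dim=1)
    confidence, pseudo_labels = probs.max(dim=1, keepdim=True)
    assert pseudo_labels.shape == (4, 1, 10, 10)  # matches the shape asserted in the test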