diff --git a/pyproject.toml b/pyproject.toml index 0735ee5f637..5fd32b9ea42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,7 +79,7 @@ xpu = [ "intel-extension-for-pytorch==2.1.30+xpu", "oneccl_bind_pt==2.1.300+xpu", "lightning==2.2", - "pytorchcv", + "pytorchcv==0.0.67", "timm==1.0.3", "openvino==2024.3", "openvino-dev==2024.3", @@ -93,7 +93,7 @@ xpu = [ base = [ "torch==2.2.2", "lightning==2.3.3", - "pytorchcv", + "pytorchcv==0.0.67", "timm==1.0.3", "openvino==2024.3", "openvino-dev==2024.3", diff --git a/src/otx/algo/classification/torchvision_model.py b/src/otx/algo/classification/torchvision_model.py index 198e69f9e40..fd4a9e6f192 100644 --- a/src/otx/algo/classification/torchvision_model.py +++ b/src/otx/algo/classification/torchvision_model.py @@ -429,7 +429,6 @@ def __init__( ) -> None: self.backbone = backbone self.freeze_backbone = freeze_backbone - self.train_type = train_type self.task = task # TODO(@harimkang): Need to make it configurable. @@ -447,6 +446,7 @@ def __init__( metric=metric, torch_compile=torch_compile, input_size=input_size, + train_type=train_type, ) self.input_size: tuple[int, int] diff --git a/src/otx/algo/segmentation/backbones/__init__.py b/src/otx/algo/segmentation/backbones/__init__.py index 1c7a4398551..4c7a44cee9b 100644 --- a/src/otx/algo/segmentation/backbones/__init__.py +++ b/src/otx/algo/segmentation/backbones/__init__.py @@ -4,7 +4,7 @@ """Backbone modules for OTX segmentation model.""" from .dinov2 import DinoVisionTransformer -from .litehrnet import LiteHRNet +from .litehrnet import LiteHRNetBackbone from .mscan import MSCAN -__all__ = ["LiteHRNet", "DinoVisionTransformer", "MSCAN"] +__all__ = ["LiteHRNetBackbone", "DinoVisionTransformer", "MSCAN"] diff --git a/src/otx/algo/segmentation/backbones/dinov2.py b/src/otx/algo/segmentation/backbones/dinov2.py index 6abf733165a..5468870ffef 100644 --- a/src/otx/algo/segmentation/backbones/dinov2.py +++ b/src/otx/algo/segmentation/backbones/dinov2.py @@ -13,14 +13,13 @@ import torch from torch import nn -from otx.algo.modules.base_module import BaseModule from otx.algo.utils.mmengine_utils import load_checkpoint_to_model, load_from_http from otx.utils.utils import get_class_initial_arguments logger = logging.getLogger() -class DinoVisionTransformer(BaseModule): +class DinoVisionTransformer(nn.Module): """DINO-v2 Model.""" def __init__( @@ -28,10 +27,9 @@ def __init__( name: str, freeze_backbone: bool, out_index: list[int], - init_cfg: dict | None = None, pretrained_weights: str | None = None, ): - super().__init__(init_cfg) + super().__init__() self._init_args = get_class_initial_arguments() ci_data_root = os.environ.get("CI_DATA_ROOT") diff --git a/src/otx/algo/segmentation/backbones/litehrnet.py b/src/otx/algo/segmentation/backbones/litehrnet.py index ba98a8b4650..8520b38db69 100644 --- a/src/otx/algo/segmentation/backbones/litehrnet.py +++ b/src/otx/algo/segmentation/backbones/litehrnet.py @@ -10,7 +10,7 @@ from __future__ import annotations from pathlib import Path -from typing import Callable +from typing import Any, Callable, ClassVar import torch import torch.utils.checkpoint as cp @@ -18,11 +18,7 @@ from torch.nn import functional from otx.algo.modules import Conv2dModule, build_norm_layer -from otx.algo.modules.base_module import BaseModule from otx.algo.segmentation.modules import ( - AsymmetricPositionAttentionModule, - IterativeAggregator, - LocalAttentionModule, channel_shuffle, ) from otx.algo.utils.mmengine_utils import load_checkpoint_to_model, load_from_http @@ -1191,7 +1187,7 
@@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]: return out -class LiteHRNet(BaseModule): +class NNLiteHRNet(nn.Module): """Lite-HRNet backbone. `High-Resolution Representations for Labeling Pixels and Regions @@ -1212,44 +1208,34 @@ class LiteHRNet(BaseModule): def __init__( self, - extra: dict, + stem: dict[str, Any], + num_stages: int, + stages_spec: dict[str, Any], in_channels: int = 3, - norm_cfg: dict | None = None, + norm_cfg: dict[str, Any] | None = None, norm_eval: bool = False, with_cp: bool = False, zero_init_residual: bool = False, dropout: float | None = None, - init_cfg: dict | None = None, pretrained_weights: str | None = None, ) -> None: """Init.""" - super().__init__(init_cfg=init_cfg) + super().__init__() if norm_cfg is None: - norm_cfg = {"type": "BN"} + norm_cfg = {"type": "BN", "requires_grad": True} - self.extra = extra self.norm_cfg = norm_cfg self.norm_eval = norm_eval self.with_cp = with_cp self.zero_init_residual = zero_init_residual self.stem = Stem( in_channels, - input_norm=self.extra["stem"]["input_norm"], - stem_channels=self.extra["stem"]["stem_channels"], - out_channels=self.extra["stem"]["out_channels"], - expand_ratio=self.extra["stem"]["expand_ratio"], - strides=self.extra["stem"]["strides"], - extra_stride=self.extra["stem"]["extra_stride"], norm_cfg=self.norm_cfg, + **stem, ) - - self.enable_stem_pool = self.extra["stem"].get("out_pool", False) - if self.enable_stem_pool: - self.stem_pool = nn.AvgPool2d(kernel_size=3, stride=2) - - self.num_stages = self.extra["num_stages"] - self.stages_spec = self.extra["stages_spec"] + self.num_stages = num_stages + self.stages_spec = stages_spec num_channels_last = [ self.stem.out_channels, @@ -1273,80 +1259,6 @@ def __init__( ) setattr(self, f"stage{i}", stage) - self.out_modules = None - if self.extra.get("out_modules") is not None: - out_modules = [] - in_modules_channels, out_modules_channels = num_channels_last[-1], None - if self.extra["out_modules"]["conv"]["enable"]: - out_modules_channels = self.extra["out_modules"]["conv"]["channels"] - out_modules.append( - Conv2dModule( - in_channels=in_modules_channels, - out_channels=out_modules_channels, - kernel_size=1, - stride=1, - padding=0, - norm_cfg=self.norm_cfg, - activation_callable=nn.ReLU, - ), - ) - in_modules_channels = out_modules_channels - if self.extra["out_modules"]["position_att"]["enable"]: - out_modules.append( - AsymmetricPositionAttentionModule( - in_channels=in_modules_channels, - key_channels=self.extra["out_modules"]["position_att"]["key_channels"], - value_channels=self.extra["out_modules"]["position_att"]["value_channels"], - psp_size=self.extra["out_modules"]["position_att"]["psp_size"], - norm_cfg=self.norm_cfg, - ), - ) - if self.extra["out_modules"]["local_att"]["enable"]: - out_modules.append( - LocalAttentionModule( - num_channels=in_modules_channels, - norm_cfg=self.norm_cfg, - ), - ) - - if len(out_modules) > 0: - self.out_modules = nn.Sequential(*out_modules) - num_channels_last.append(in_modules_channels) - - self.add_stem_features = self.extra.get("add_stem_features", False) - if self.add_stem_features: - self.stem_transition = nn.Sequential( - Conv2dModule( - self.stem.out_channels, - self.stem.out_channels, - kernel_size=3, - stride=1, - padding=1, - groups=self.stem.out_channels, - norm_cfg=norm_cfg, - activation_callable=None, - ), - Conv2dModule( - self.stem.out_channels, - num_channels_last[0], - kernel_size=1, - stride=1, - padding=0, - norm_cfg=norm_cfg, - activation_callable=nn.ReLU, - ), - ) - - 
num_channels_last = [num_channels_last[0], *num_channels_last] - - self.with_aggregator = self.extra.get("out_aggregator") and self.extra["out_aggregator"]["enable"] - if self.with_aggregator: - self.aggregator = IterativeAggregator( - in_channels=num_channels_last, - min_channels=self.extra["out_aggregator"].get("min_channels", None), - norm_cfg=self.norm_cfg, - ) - if pretrained_weights is not None: self.load_pretrained_weights(pretrained_weights, prefix="backbone") @@ -1479,11 +1391,7 @@ def _make_stage( def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward function.""" stem_outputs = self.stem(x) - y_x2 = y_x4 = stem_outputs - y = y_x4 - - if self.enable_stem_pool: - y = self.stem_pool(y) + y = stem_outputs y_list = [y] for i in range(self.num_stages): @@ -1502,21 +1410,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: stage_module = getattr(self, f"stage{i}") y_list = stage_module(stage_inputs) - if self.out_modules is not None: - y_list.append(self.out_modules(y_list[-1])) - - if self.add_stem_features: - y_stem = self.stem_transition(y_x2) - y_list = [y_stem, *y_list] - - out = y_list - if self.with_aggregator: - out = self.aggregator(out) - - if self.extra.get("add_input", False): - out = [x, *out] - - return out + return y_list def load_pretrained_weights(self, pretrained: str | None = None, prefix: str = "") -> None: """Initialize weights.""" @@ -1530,3 +1424,82 @@ def load_pretrained_weights(self, pretrained: str | None = None, prefix: str = " print(f"init weight - {pretrained}") if checkpoint is not None: load_checkpoint_to_model(self, checkpoint, prefix=prefix) + + +class LiteHRNetBackbone: + """LiteHRNet backbone factory.""" + + LITEHRNET_CFG: ClassVar[dict[str, Any]] = { + "lite_hrnet_s": { + "stem": { + "stem_channels": 32, + "out_channels": 32, + "expand_ratio": 1, + "strides": [2, 2], + "extra_stride": True, + "input_norm": False, + }, + "num_stages": 2, + "stages_spec": { + "num_modules": [4, 4], + "num_branches": [2, 3], + "num_blocks": [2, 2], + "module_type": ["LITE", "LITE"], + "with_fuse": [True, True], + "reduce_ratios": [8, 8], + "num_channels": [[60, 120], [60, 120, 240]], + }, + "pretrained_weights": "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/models/custom_semantic_segmentation/litehrnetsv2_imagenet1k_rsc.pth", + }, + "lite_hrnet_18": { + "stem": { + "stem_channels": 32, + "out_channels": 32, + "expand_ratio": 1, + "strides": [2, 2], + "extra_stride": False, + "input_norm": False, + }, + "num_stages": 3, + "stages_spec": { + "num_modules": [2, 4, 2], + "num_branches": [2, 3, 4], + "num_blocks": [2, 2, 2], + "module_type": ["LITE", "LITE", "LITE"], + "with_fuse": [True, True, True], + "reduce_ratios": [8, 8, 8], + "num_channels": [[40, 80], [40, 80, 160], [40, 80, 160, 320]], + }, + "pretrained_weights": "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/models/custom_semantic_segmentation/litehrnet18_imagenet1k_rsc.pth", + }, + "lite_hrnet_x": { + "stem": { + "stem_channels": 60, + "out_channels": 60, + "expand_ratio": 1, + "strides": [2, 1], + "extra_stride": False, + "input_norm": False, + }, + "num_stages": 4, + "stages_spec": { + "weighting_module_version": "v1", + "num_modules": [2, 4, 4, 2], + "num_branches": [2, 3, 4, 5], + "num_blocks": [2, 2, 2, 2], + "module_type": ["LITE", "LITE", "LITE", "LITE"], + "with_fuse": [True, True, True, True], + "reduce_ratios": [2, 4, 8, 8], + "num_channels": [[18, 60], [18, 60, 80], [18, 60, 80, 160], [18, 60, 80, 160, 320]], + }, + 
"pretrained_weights": "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/models/custom_semantic_segmentation/litehrnetxv3_imagenet1k_rsc.pth", + }, + } + + def __new__(cls, version: str) -> NNLiteHRNet: + """Constructor for LiteHRNet backbone.""" + if version not in cls.LITEHRNET_CFG: + msg = f"model type '{version}' is not supported" + raise KeyError(msg) + + return NNLiteHRNet(**cls.LITEHRNET_CFG[version]) diff --git a/src/otx/algo/segmentation/backbones/mscan.py b/src/otx/algo/segmentation/backbones/mscan.py index cc1bb96db8b..7226fb4a403 100644 --- a/src/otx/algo/segmentation/backbones/mscan.py +++ b/src/otx/algo/segmentation/backbones/mscan.py @@ -6,7 +6,7 @@ from __future__ import annotations from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, ClassVar import torch from torch import nn @@ -329,7 +329,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]: return x, h, w -class MSCAN(BaseModule): +class NNMSCAN(nn.Module): """SegNeXt Multi-Scale Convolutional Attention Network (MCSAN) backbone. This backbone is the implementation of `SegNeXt: Rethinking @@ -360,23 +360,22 @@ class MSCAN(BaseModule): def __init__( self, in_channels: int = 3, - embed_dims: list[int] = [64, 128, 256, 512], # noqa: B006 - mlp_ratios: list[int] = [4, 4, 4, 4], # noqa: B006 + embed_dims: list[int] = [64, 128, 320, 512], # noqa: B006 + mlp_ratios: list[int] = [8, 8, 4, 4], # noqa: B006 drop_rate: float = 0.0, - drop_path_rate: float = 0.0, + drop_path_rate: float = 0.1, depths: list[int] = [3, 4, 6, 3], # noqa: B006 num_stages: int = 4, attention_kernel_sizes: list[int | list[int]] = [5, [1, 7], [1, 11], [1, 21]], # noqa: B006 attention_kernel_paddings: list[int | list[int]] = [2, [0, 3], [0, 5], [0, 10]], # noqa: B006 activation_callable: Callable[..., nn.Module] = nn.GELU, norm_cfg: dict[str, str | bool] | None = None, - init_cfg: dict[str, str] | list[dict[str, str]] | None = None, pretrained_weights: str | None = None, ) -> None: """Initialize a MSCAN backbone.""" - super().__init__(init_cfg=init_cfg) + super().__init__() if norm_cfg is None: - norm_cfg = {"type": "SyncBN", "requires_grad": True} + norm_cfg = {"type": "BN", "requires_grad": True} self.depths = depths self.num_stages = num_stages @@ -450,3 +449,31 @@ def load_pretrained_weights(self, pretrained: str | None = None, prefix: str = " print(f"init weight - {pretrained}") if checkpoint is not None: load_checkpoint_to_model(self, checkpoint, prefix=prefix) + + +class MSCAN: + """MSCAN backbone factory.""" + + MSCAN_CFG: ClassVar[dict[str, Any]] = { + "segnext_tiny": { + "depths": [3, 3, 5, 2], + "embed_dims": [32, 64, 160, 256], + "pretrained_weights": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_t_20230227-119e8c9f.pth", + }, + "segnext_small": { + "depths": [2, 2, 4, 2], + "pretrained_weights": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_s_20230227-f33ccdf2.pth", + }, + "segnext_base": { + "depths": [3, 3, 12, 3], + "pretrained_weights": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_b_20230227-3ab7d230.pth", + }, + } + + def __new__(cls, version: str) -> NNMSCAN: + """Constructor for MSCAN backbone.""" + if version not in cls.MSCAN_CFG: + msg = f"model type '{version}' is not supported" + raise KeyError(msg) + + return NNMSCAN(**cls.MSCAN_CFG[version]) diff --git a/src/otx/algo/segmentation/dino_v2_seg.py 
b/src/otx/algo/segmentation/dino_v2_seg.py index 28c6048f023..ab79104509a 100644 --- a/src/otx/algo/segmentation/dino_v2_seg.py +++ b/src/otx/algo/segmentation/dino_v2_seg.py @@ -7,92 +7,37 @@ from typing import TYPE_CHECKING, Any, ClassVar -import torch - from otx.algo.segmentation.backbones import DinoVisionTransformer from otx.algo.segmentation.heads import FCNHead -from otx.algo.segmentation.segmentors import BaseSegmModel, MeanTeacher -from otx.core.data.entity.segmentation import SegBatchDataEntity -from otx.core.metrics.dice import SegmCallable -from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable -from otx.core.model.segmentation import TorchVisionCompatibleModel +from otx.algo.segmentation.losses import CrossEntropyLossWithIgnore +from otx.algo.segmentation.segmentors import BaseSegmModel +from otx.core.model.segmentation import OTXSegmentationModel if TYPE_CHECKING: - from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable from torch import nn from typing_extensions import Self - from otx.core.metrics import MetricCallable - from otx.core.schedulers import LRSchedulerListCallable - from otx.core.types.label import LabelInfoTypes - -class DinoV2Seg(BaseSegmModel): +class DinoV2Seg(OTXSegmentationModel): """DinoV2Seg Model.""" - default_backbone_configuration: ClassVar[dict[str, Any]] = { - "name": "dinov2_vits14", - "freeze_backbone": True, - "out_index": [8, 9, 10, 11], - } - default_decode_head_configuration: ClassVar[dict[str, Any]] = { - "norm_cfg": {"type": "SyncBN", "requires_grad": True}, - "in_channels": [384, 384, 384, 384], - "in_index": [0, 1, 2, 3], - "input_transform": "resize_concat", - "channels": 1536, - "kernel_size": 1, - "num_convs": 1, - "concat_input": False, - "dropout_ratio": -1, - "align_corners": False, - "pretrained_weights": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_ade20k_linear_head.pth", - } - - -class OTXDinoV2Seg(TorchVisionCompatibleModel): - """DinoV2Seg Model.""" + AVAILABLE_MODEL_VERSIONS: ClassVar[list[str]] = [ + "dinov2_vits14", + ] - input_size_multiplier = 14 + def _build_model(self) -> nn.Module: + if self.model_version not in self.AVAILABLE_MODEL_VERSIONS: + msg = f"Model version {self.model_version} is not supported." 
+ raise ValueError(msg) - def __init__( - self, - label_info: LabelInfoTypes, - input_size: tuple[int, int] = (560, 560), - optimizer: OptimizerCallable = DefaultOptimizerCallable, - scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, - metric: MetricCallable = SegmCallable, # type: ignore[assignment] - torch_compile: bool = False, - backbone_configuration: dict[str, Any] | None = None, - decode_head_configuration: dict[str, Any] | None = None, - criterion_configuration: list[dict[str, Any]] | None = None, - export_image_configuration: dict[str, Any] | None = None, - name_base_model: str = "semantic_segmentation_model", - ): - super().__init__( - label_info=label_info, - input_size=input_size, - optimizer=optimizer, - scheduler=scheduler, - metric=metric, - torch_compile=torch_compile, - backbone_configuration=backbone_configuration, - decode_head_configuration=decode_head_configuration, - criterion_configuration=criterion_configuration, - export_image_configuration=export_image_configuration, - name_base_model=name_base_model, - ) + backbone = DinoVisionTransformer(name=self.model_version, freeze_backbone=True, out_index=[8, 9, 10, 11]) + decode_head = FCNHead(self.model_version, num_classes=self.num_classes) + criterion = CrossEntropyLossWithIgnore(ignore_index=self.label_info.ignore_index) # type: ignore[attr-defined] - def _create_model(self) -> nn.Module: - # merge configurations with defaults overriding them - backbone_configuration = DinoV2Seg.default_backbone_configuration | self.backbone_configuration - decode_head_configuration = DinoV2Seg.default_decode_head_configuration | self.decode_head_configuration - backbone = DinoVisionTransformer(**backbone_configuration) - decode_head = FCNHead(num_classes=self.num_classes, **decode_head_configuration) - return DinoV2Seg( + return BaseSegmModel( backbone=backbone, decode_head=decode_head, - criterion_configuration=self.criterion_configuration, + criterion=criterion, ) @property @@ -107,46 +52,3 @@ def to(self, *args, **kwargs) -> Self: msg = f"{type(self).__name__} doesn't support XPU." 
raise RuntimeError(msg) return ret - - -class DinoV2SegSemiSL(OTXDinoV2Seg): - """DinoV2SegSemiSL Model.""" - - def _customize_inputs(self, entity: SegBatchDataEntity) -> dict[str, Any]: - if not isinstance(entity, dict): - if self.training: - msg = "unlabeled inputs should be provided for semi-sl training" - raise RuntimeError(msg) - return super()._customize_inputs(entity) - - entity["labeled"].masks = torch.stack(entity["labeled"].masks).long() - w_u_images = entity["weak_transforms"].images - s_u_images = entity["strong_transforms"].images - unlabeled_img_metas = entity["weak_transforms"].imgs_info - labeled_inputs = entity["labeled"] - - return { - "inputs": labeled_inputs.images, - "unlabeled_weak_images": w_u_images, - "unlabeled_strong_images": s_u_images, - "global_step": self.trainer.global_step, - "steps_per_epoch": self.trainer.num_training_batches, - "img_metas": labeled_inputs.imgs_info, - "unlabeled_img_metas": unlabeled_img_metas, - "masks": labeled_inputs.masks, - "mode": "loss", - } - - def _create_model(self) -> nn.Module: - # merge configurations with defaults overriding them - backbone_configuration = DinoV2Seg.default_backbone_configuration | self.backbone_configuration - decode_head_configuration = DinoV2Seg.default_decode_head_configuration | self.decode_head_configuration - backbone = DinoVisionTransformer(**backbone_configuration) - decode_head = FCNHead(num_classes=self.num_classes, **decode_head_configuration) - base_model = DinoV2Seg( - backbone=backbone, - decode_head=decode_head, - criterion_configuration=self.criterion_configuration, - ) - - return MeanTeacher(base_model, unsup_weight=0.7, drop_unrel_pixels_percent=20, semisl_start_epoch=2) diff --git a/src/otx/algo/segmentation/heads/base_segm_head.py b/src/otx/algo/segmentation/heads/base_segm_head.py index 419fea64071..760fbb6e8ea 100644 --- a/src/otx/algo/segmentation/heads/base_segm_head.py +++ b/src/otx/algo/segmentation/heads/base_segm_head.py @@ -5,7 +5,7 @@ from __future__ import annotations -from abc import ABCMeta, abstractmethod +from abc import abstractmethod from pathlib import Path from typing import Callable @@ -16,7 +16,7 @@ from otx.algo.utils.mmengine_utils import load_checkpoint_to_model, load_from_http -class BaseSegmHead(nn.Module, metaclass=ABCMeta): +class BaseSegmHead(nn.Module): """Base class for segmentation heads. Args: @@ -45,9 +45,8 @@ def __init__( activation_callable: Callable[..., nn.Module] | None = nn.ReLU, in_index: int | list[int] = -1, input_transform: str | None = None, - ignore_index: int = 255, align_corners: bool = False, - pretrained_weights: str | None = None, + pretrained_weights: Path | str | None = None, ) -> None: """Initialize the BaseSegmHead.""" super().__init__() @@ -61,7 +60,6 @@ def __init__( msg = f'"in_index" expects a list, but got {type(in_index)}' raise TypeError(msg) self.in_index = in_index - self.ignore_index = ignore_index self.align_corners = align_corners if input_transform == "resize_concat": @@ -141,7 +139,7 @@ def cls_seg(self, feat: torch.Tensor) -> torch.Tensor: def load_pretrained_weights( self, - pretrained: str | None = None, + pretrained: Path | str | None = None, prefix: str = "", ) -> None: """Initialize weights. 
@@ -159,7 +157,15 @@ def load_pretrained_weights(
             checkpoint = torch.load(pretrained, map_location=torch.device("cpu"))
             print(f"Init weights - {pretrained}")
         elif pretrained is not None:
-            checkpoint = load_from_http(pretrained, "cpu")
+            cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints"
+            if isinstance(pretrained, Path):
+                msg = "pretrained path doesn't exist"
+                raise ValueError(msg)
+            checkpoint = load_from_http(
+                filename=pretrained,
+                map_location="cpu",
+                model_dir=cache_dir,
+            )
             print(f"Init weights - {pretrained}")
         if checkpoint is not None:
             load_checkpoint_to_model(self, checkpoint, prefix=prefix)
diff --git a/src/otx/algo/segmentation/heads/fcn_head.py b/src/otx/algo/segmentation/heads/fcn_head.py
index da79e2db239..5fce1064e75 100644
--- a/src/otx/algo/segmentation/heads/fcn_head.py
+++ b/src/otx/algo/segmentation/heads/fcn_head.py
@@ -5,7 +5,7 @@
 from __future__ import annotations
-from typing import Any
+from typing import TYPE_CHECKING, Any, Callable, ClassVar
 import torch
 from torch import Tensor, nn
@@ -15,8 +15,11 @@
 from .base_segm_head import BaseSegmHead
+if TYPE_CHECKING:
+    from pathlib import Path
-class FCNHead(BaseSegmHead):
+
+class NNFCNHead(BaseSegmHead):
     """Fully Convolution Networks for Semantic Segmentation with aggregation.
     This head is implemented of `FCNNet `_.
@@ -33,17 +36,22 @@ def __init__(
         self,
         in_channels: list[int] | int,
         in_index: list[int] | int,
+        channels: int,
         norm_cfg: dict[str, Any] | None = None,
         input_transform: str | None = None,
-        num_convs: int = 2,
-        kernel_size: int = 3,
-        concat_input: bool = True,
+        num_classes: int = 80,
+        num_convs: int = 1,
+        kernel_size: int = 1,
+        concat_input: bool = False,
         dilation: int = 1,
         enable_aggregator: bool = False,
         aggregator_min_channels: int = 0,
         aggregator_merge_norm: str | None = None,
         aggregator_use_concat: bool = False,
-        **kwargs: Any,  # noqa: ANN401
+        align_corners: bool = False,
+        dropout_ratio: float = -1,
+        activation_callable: Callable[..., nn.Module] | None = nn.ReLU,
+        pretrained_weights: Path | str | None = None,
     ) -> None:
         """Initialize a Fully Convolution Networks head.
@@ -60,6 +68,8 @@ def __init__( if num_convs < 0 and dilation <= 0: msg = "num_convs and dilation should be larger than 0" raise ValueError(msg) + if norm_cfg is None: + norm_cfg = {"type": "BN", "requires_grad": True} self.num_convs = num_convs self.concat_input = concat_input @@ -91,7 +101,12 @@ def __init__( norm_cfg=norm_cfg, input_transform=input_transform, in_channels=in_channels, - **kwargs, + align_corners=align_corners, + dropout_ratio=dropout_ratio, + channels=channels, + num_classes=num_classes, + activation_callable=activation_callable, + pretrained_weights=pretrained_weights, ) self.aggregator = aggregator @@ -175,3 +190,52 @@ def _transform_inputs(self, inputs: list[Tensor]) -> Tensor | list: Tensor: The transformed inputs """ return self.aggregator(inputs)[0] if self.aggregator is not None else super()._transform_inputs(inputs) + + +class FCNHead: + """FCNHead factory for segmentation.""" + + FCNHEAD_CFG: ClassVar[dict[str, Any]] = { + "lite_hrnet_s": { + "in_channels": [60, 120, 240], + "in_index": [0, 1, 2], + "input_transform": "multiple_select", + "channels": 60, + "enable_aggregator": True, + "aggregator_merge_norm": "None", + "aggregator_use_concat": False, + }, + "lite_hrnet_18": { + "in_channels": [40, 80, 160, 320], + "in_index": [0, 1, 2, 3], + "input_transform": "multiple_select", + "channels": 40, + "enable_aggregator": True, + }, + "lite_hrnet_x": { + "in_channels": [18, 60, 80, 160, 320], + "in_index": [0, 1, 2, 3, 4], + "input_transform": "multiple_select", + "channels": 60, + "enable_aggregator": True, + "aggregator_min_channels": 60, + "aggregator_merge_norm": "None", + "aggregator_use_concat": False, + }, + "dinov2_vits14": { + "norm_cfg": {"type": "SyncBN", "requires_grad": True}, + "in_channels": [384, 384, 384, 384], + "in_index": [0, 1, 2, 3], + "input_transform": "resize_concat", + "channels": 1536, + "pretrained_weights": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_ade20k_linear_head.pth", + }, + } + + def __new__(cls, version: str, num_classes: int) -> NNFCNHead: + """Constructor for FCNHead.""" + if version not in cls.FCNHEAD_CFG: + msg = f"model type '{version}' is not supported" + raise KeyError(msg) + + return NNFCNHead(**cls.FCNHEAD_CFG[version], num_classes=num_classes) diff --git a/src/otx/algo/segmentation/heads/ham_head.py b/src/otx/algo/segmentation/heads/ham_head.py index cd079752a15..ddf0c8685ce 100644 --- a/src/otx/algo/segmentation/heads/ham_head.py +++ b/src/otx/algo/segmentation/heads/ham_head.py @@ -5,7 +5,7 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any, Callable, ClassVar import torch import torch.nn.functional as f @@ -16,6 +16,9 @@ from .base_segm_head import BaseSegmHead +if TYPE_CHECKING: + from pathlib import Path + class Hamburger(nn.Module): """Hamburger Module. @@ -34,7 +37,6 @@ def __init__( ham_channels: int, ham_kwargs: dict[str, Any], norm_cfg: dict[str, Any] | None = None, - **kwargs: Any, # noqa: ANN401 ) -> None: """Initialize Hamburger Module. 
@@ -61,14 +63,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return f.relu(x + enjoy, inplace=True)
-class LightHamHead(BaseSegmHead):
+class NNLightHamHead(BaseSegmHead):
     """SegNeXt decode head."""
     def __init__(
         self,
+        in_channels: int | list[int],
+        channels: int,
+        num_classes: int,
+        dropout_ratio: float = 0.1,
+        norm_cfg: dict[str, Any] | None = None,
+        activation_callable: Callable[..., nn.Module] | None = nn.ReLU,
+        in_index: int | list[int] = [1, 2, 3],  # noqa: B006
+        input_transform: str | None = "multiple_select",
+        align_corners: bool = False,
+        pretrained_weights: Path | str | None = None,
         ham_channels: int = 512,
         ham_kwargs: dict[str, Any] | None = None,
-        **kwargs: Any,  # noqa: ANN401
     ) -> None:
         """SegNeXt decode head.
@@ -84,18 +95,36 @@ def __init__(
         Args:
             ham_channels (int): input channels for Hamburger. Defaults to 512.
-            ham_kwargs (Dict[str, Any]): kwagrs for Ham. Defaults to an empty dictionary.
+            ham_kwargs (Dict[str, Any] | None): kwargs for Ham.
+                If None, {"md_r": 16, "md_s": 1, "eval_steps": 7, "train_steps": 6} will be used.
         Returns:
             None
         """
-        super().__init__(input_transform="multiple_select", **kwargs)
+        if norm_cfg is None:
+            norm_cfg = {"num_groups": 32, "requires_grad": True, "type": "GN"}
+
+        super().__init__(
+            input_transform=input_transform,
+            in_channels=in_channels,
+            channels=channels,
+            num_classes=num_classes,
+            dropout_ratio=dropout_ratio,
+            norm_cfg=norm_cfg,
+            activation_callable=activation_callable,
+            in_index=in_index,
+            align_corners=align_corners,
+            pretrained_weights=pretrained_weights,
+        )
+
         if not isinstance(self.in_channels, list):
             msg = f"Input channels type must be list, but got {type(self.in_channels)}"
             raise TypeError(msg)
-        self.ham_channels: int = ham_channels
-        self.ham_kwargs: dict[str, Any] = ham_kwargs if ham_kwargs is not None else {}
+        self.ham_channels = ham_channels
+        self.ham_kwargs = (
+            ham_kwargs if ham_kwargs is not None else {"md_r": 16, "md_s": 1, "eval_steps": 7, "train_steps": 6}
+        )
         self.squeeze = Conv2dModule(
             sum(self.in_channels),
@@ -105,7 +134,7 @@ def __init__(
             activation_callable=self.activation_callable,
         )
-        self.hamburger = Hamburger(self.ham_channels, ham_kwargs=self.ham_kwargs, **kwargs)
+        self.hamburger = Hamburger(self.ham_channels, ham_kwargs=self.ham_kwargs, norm_cfg=norm_cfg)
         self.align = Conv2dModule(
             self.ham_channels,
@@ -148,8 +177,6 @@ def __init__(
         md_r: int = 64,
         train_steps: int = 6,
         eval_steps: int = 7,
-        inv_t: int = 1,
-        rand_init: bool = True,
     ) -> None:
         """Initialize Non-negative Matrix Factorization (NMF) module.
@@ -172,7 +199,6 @@ def __init__(
         self.train_steps = train_steps
         self.eval_steps = eval_steps
-        self.rand_init = rand_init
         bases = f.normalize(torch.rand((self.s, ham_channels // self.s, self.r)))
         self.bases = torch.nn.parameter.Parameter(bases, requires_grad=False)
         self.inv_t = 1
@@ -283,3 +309,33 @@ def compute_coef(self, x: torch.Tensor, bases: torch.Tensor, coef: torch.Tensor)
         # multiplication update
         return coef * numerator / (denominator + 1e-6)
+
+
+class LightHamHead:
+    """LightHamHead factory for segmentation."""
+
+    HAMHEAD_CFG: ClassVar[dict[str, Any]] = {
+        "segnext_base": {
+            "in_channels": [128, 320, 512],
+            "channels": 512,
+            "ham_channels": 512,
+        },
+        "segnext_small": {
+            "in_channels": [128, 320, 512],
+            "channels": 256,
+            "ham_channels": 256,
+        },
+        "segnext_tiny": {
+            "in_channels": [64, 160, 256],
+            "channels": 256,
+            "ham_channels": 256,
+        },
+    }
+
+    def __new__(cls, version: str, num_classes: int) -> NNLightHamHead:
+        """Constructor for LightHamHead."""
+        if version not in cls.HAMHEAD_CFG:
+            msg = f"model type '{version}' is not supported"
+            raise KeyError(msg)
+
+        return NNLightHamHead(**cls.HAMHEAD_CFG[version], num_classes=num_classes)
diff --git a/src/otx/algo/segmentation/huggingface_model.py b/src/otx/algo/segmentation/huggingface_model.py
index f00a7faceb6..83629896ed8 100644
--- a/src/otx/algo/segmentation/huggingface_model.py
+++ b/src/otx/algo/segmentation/huggingface_model.py
@@ -87,6 +87,7 @@ def __init__(
     def _create_model(self) -> nn.Module:
         model_config, _ = PretrainedConfig.get_config_dict(self.model_name)
         kwargs = {}
+
         if "image_size" in model_config:
             kwargs["image_size"] = self.input_size[-1]
@@ -148,7 +149,7 @@ def _exporter(self) -> OTXModelExporter:
         return OTXNativeModelExporter(
             task_level_export_parameters=self._export_parameters,
-            input_size=self.input_size,
+            input_size=(1, 3, *self.input_size),
             mean=image_mean,
             std=image_std,
             resize_mode="standard",
diff --git a/src/otx/algo/segmentation/litehrnet.py b/src/otx/algo/segmentation/litehrnet.py
index 6846ebd0f78..c2ad27296b4 100644
--- a/src/otx/algo/segmentation/litehrnet.py
+++ b/src/otx/algo/segmentation/litehrnet.py
@@ -7,568 +7,42 @@
 from typing import TYPE_CHECKING, Any, ClassVar
-import torch
 from torch.onnx import OperatorExportTypes
-from otx.algo.segmentation.backbones import LiteHRNet
+from otx.algo.segmentation.backbones import LiteHRNetBackbone
 from otx.algo.segmentation.heads import FCNHead
-from otx.algo.segmentation.segmentors import BaseSegmModel, MeanTeacher
+from otx.algo.segmentation.losses import CrossEntropyLossWithIgnore
+from otx.algo.segmentation.segmentors import BaseSegmModel
 from otx.algo.utils.support_otx_v1 import OTXv1Helper
-from otx.core.data.entity.segmentation import SegBatchDataEntity
 from otx.core.exporter.base import OTXModelExporter
 from otx.core.exporter.native import OTXNativeModelExporter
-from otx.core.metrics.dice import SegmCallable
-from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
-from otx.core.model.segmentation import TorchVisionCompatibleModel
+from otx.core.model.segmentation import OTXSegmentationModel
 if TYPE_CHECKING:
-    from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
     from torch import nn
-    from otx.core.metrics import MetricCallable
-    from otx.core.schedulers import LRSchedulerListCallable
-    from otx.core.types.label import LabelInfoTypes
-class LiteHRNetS(BaseSegmModel):
-    """LiteHRNetS Model."""
-
-    default_backbone_configuration: ClassVar[dict[str, Any]] = {
-        "norm_cfg": {"type":
"BN", "requires_grad": True}, - "norm_eval": False, - "extra": { - "stem": { - "stem_channels": 32, - "out_channels": 32, - "expand_ratio": 1, - "strides": [2, 2], - "extra_stride": True, - "input_norm": False, - }, - "num_stages": 2, - "stages_spec": { - "num_modules": [4, 4], - "num_branches": [2, 3], - "num_blocks": [2, 2], - "module_type": ["LITE", "LITE"], - "with_fuse": [True, True], - "reduce_ratios": [8, 8], - "num_channels": [[60, 120], [60, 120, 240]], - }, - "out_modules": { - "conv": {"enable": False, "channels": 160}, - "position_att": {"enable": False, "key_channels": 64, "value_channels": 240, "psp_size": [1, 3, 6, 8]}, - "local_att": {"enable": False}, - }, - "out_aggregator": {"enable": False}, - "add_input": False, - }, - "pretrained_weights": "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/models/custom_semantic_segmentation/litehrnetsv2_imagenet1k_rsc.pth", - } - default_decode_head_configuration: ClassVar[dict[str, Any]] = { - "norm_cfg": {"type": "BN", "requires_grad": True}, - "in_channels": [60, 120, 240], - "in_index": [0, 1, 2], - "input_transform": "multiple_select", - "channels": 60, - "kernel_size": 1, - "num_convs": 1, - "concat_input": False, - "enable_aggregator": True, - "aggregator_merge_norm": "None", - "aggregator_use_concat": False, - "dropout_ratio": -1, - "align_corners": False, - } - - @property - def ignore_scope(self) -> dict[str, str | dict[str, list[str]]]: - """The ignored scope for LiteHRNetS.""" - ignored_scope_names = [ - "__module.model.backbone.stage0.0.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage0.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage0.0.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage0.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage0.0/aten::add_/Add_1", - "__module.model.backbone.stage0.1.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage0.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage0.1.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage0.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage0.1/aten::add_/Add_1", - "__module.model.backbone.stage0.2.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage0.2.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage0.2.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage0.2.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage0.2/aten::add_/Add_1", - "__module.model.backbone.stage0.3.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage0.3.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage0.3.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage0.3.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage0.3/aten::add_/Add_1", - "__module.model.backbone.stage1.0.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage1.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage1.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", - 
"__module.model.backbone.stage1.0.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage1.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage1.0/aten::add_/Add_1", - "__module.model.backbone.stage1.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage1.0/aten::add_/Add_2", - "__module.model.backbone.stage1.0/aten::add_/Add_5", - "__module.model.backbone.stage1.1.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage1.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage1.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage1.1.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage1.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage1.1/aten::add_/Add_1", - "__module.model.backbone.stage1.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage1.1/aten::add_/Add_2", - "__module.model.backbone.stage1.1/aten::add_/Add_5", - "__module.model.backbone.stage1.2.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage1.2.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage1.2.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage1.2.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage1.2.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage1.2/aten::add_/Add_1", - "__module.model.backbone.stage1.2.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage1.2/aten::add_/Add_2", - "__module.model.backbone.stage1.2/aten::add_/Add_5", - "__module.model.backbone.stage1.3.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage1.3.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage1.3.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage1.3.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage1.3.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage1.3/aten::add_/Add_1", - "__module.model.backbone.stage1.3.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage1.3/aten::add_/Add_2", - "__module.model.backbone.stage1.3/aten::add_/Add_5", - "__module.model.decode_head.aggregator/aten::add/Add", - "__module.model.decode_head.aggregator/aten::add/Add_1", - ] - - return { - "ignored_scope": { - "names": ignored_scope_names, - }, - "preset": "mixed", - } - - -class LiteHRNet18(BaseSegmModel): - """LiteHRNet18 Model.""" - - default_backbone_configuration: ClassVar[dict[str, Any]] = { - "norm_eval": False, - "extra": { - "stem": { - "stem_channels": 32, - "out_channels": 32, - "expand_ratio": 1, - "strides": [2, 2], - "extra_stride": False, - "input_norm": False, - }, - "num_stages": 3, - "stages_spec": { - "num_modules": [2, 4, 2], - "num_branches": [2, 3, 4], - "num_blocks": [2, 2, 2], - "module_type": ["LITE", "LITE", "LITE"], - "with_fuse": [True, True, True], - "reduce_ratios": [8, 8, 8], - "num_channels": [[40, 80], [40, 80, 160], [40, 80, 160, 320]], - }, - "out_modules": { - "conv": {"enable": False, "channels": 320}, - 
"position_att": {"enable": False, "key_channels": 128, "value_channels": 320, "psp_size": [1, 3, 6, 8]}, - "local_att": {"enable": False}, - }, - "out_aggregator": {"enable": False}, - "add_input": False, - }, - "pretrained_weights": "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/models/custom_semantic_segmentation/litehrnet18_imagenet1k_rsc.pth", - } - default_decode_head_configuration: ClassVar[dict[str, Any]] = { - "norm_cfg": {"type": "BN", "requires_grad": True}, - "in_channels": [40, 80, 160, 320], - "in_index": [0, 1, 2, 3], - "input_transform": "multiple_select", - "channels": 40, - "enable_aggregator": True, - "kernel_size": 1, - "num_convs": 1, - "concat_input": False, - "dropout_ratio": -1, - "align_corners": False, - } - - @property - def ignore_scope(self) -> dict[str, str | dict[str, list[str]]]: - """The ignored scope of the LiteHRNet18 model.""" - ignored_scope_names = [ - "__module.model.backbone.stage0.0.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage0.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage0.0.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage0.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage0.0/aten::add_/Add_1", - "__module.model.backbone.stage0.1.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage0.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage0.1.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage0.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage0.1/aten::add_/Add_1", - "__module.model.backbone.stage1.0.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage1.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage1.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage1.0.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage1.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage1.0/aten::add_/Add_1", - "__module.model.backbone.stage1.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage1.0/aten::add_/Add_2", - "__module.model.backbone.stage1.0/aten::add_/Add_5", - "__module.model.backbone.stage1.1.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage1.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage1.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage1.1.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage1.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage1.1/aten::add_/Add_1", - "__module.model.backbone.stage1.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage1.1/aten::add_/Add_2", - "__module.model.backbone.stage1.1/aten::add_/Add_5", - "__module.model.backbone.stage1.2.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage1.2.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage1.2.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", - 
"__module.model.backbone.stage1.2.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage1.2.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage1.2/aten::add_/Add_1", - "__module.model.backbone.stage1.2.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage1.2/aten::add_/Add_2", - "__module.model.backbone.stage1.2/aten::add_/Add_5", - "__module.model.backbone.stage1.3.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage1.3.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage1.3.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage1.3.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage1.3.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage1.3/aten::add_/Add_1", - "__module.model.backbone.stage1.3.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage1.3/aten::add_/Add_2", - "__module.model.backbone.stage1.3/aten::add_/Add_5", - "__module.model.backbone.stage2.0.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage2.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage2.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage2.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_3", - "__module.model.backbone.stage2.0.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage2.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage2.0/aten::add_/Add_1", - "__module.model.backbone.stage2.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage2.0/aten::add_/Add_2", - "__module.model.backbone.stage2.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_3", - "__module.model.backbone.stage2.0/aten::add_/Add_3", - "__module.model.backbone.stage2.0/aten::add_/Add_6", - "__module.model.backbone.stage2.0/aten::add_/Add_7", - "__module.model.backbone.stage2.0/aten::add_/Add_11", - "__module.model.backbone.stage2.1.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage2.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage2.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage2.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_3", - "__module.model.backbone.stage2.1.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage2.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage2.1/aten::add_/Add_1", - "__module.model.backbone.stage2.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage2.1/aten::add_/Add_2", - "__module.model.backbone.stage2.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_3", - "__module.model.backbone.stage2.1/aten::add_/Add_3", - "__module.model.backbone.stage2.1/aten::add_/Add_6", - "__module.model.backbone.stage2.1/aten::add_/Add_7", - "__module.model.backbone.stage2.1/aten::add_/Add_11", - "__module.model.decode_head.aggregator/aten::add/Add", - "__module.model.decode_head.aggregator/aten::add/Add_1", - "__module.model.decode_head.aggregator/aten::add/Add_2", - 
"__module.model.backbone.stage2.1/aten::add_/Add", - ] - - return { - "ignored_scope": { - "patterns": ["__module.model.backbone/*"], - "names": ignored_scope_names, - }, - "preset": "mixed", - } - - -class LiteHRNetX(BaseSegmModel): - """LiteHRNetX Model.""" - - default_backbone_configuration: ClassVar[dict[str, Any]] = { - "norm_cfg": {"type": "BN", "requires_grad": True}, - "norm_eval": False, - "extra": { - "stem": { - "stem_channels": 60, - "out_channels": 60, - "expand_ratio": 1, - "strides": [2, 1], - "extra_stride": False, - "input_norm": False, - }, - "num_stages": 4, - "stages_spec": { - "weighting_module_version": "v1", - "num_modules": [2, 4, 4, 2], - "num_branches": [2, 3, 4, 5], - "num_blocks": [2, 2, 2, 2], - "module_type": ["LITE", "LITE", "LITE", "LITE"], - "with_fuse": [True, True, True, True], - "reduce_ratios": [2, 4, 8, 8], - "num_channels": [[18, 60], [18, 60, 80], [18, 60, 80, 160], [18, 60, 80, 160, 320]], - }, - "out_modules": { - "conv": {"enable": False, "channels": 320}, - "position_att": {"enable": False, "key_channels": 128, "value_channels": 320, "psp_size": [1, 3, 6, 8]}, - "local_att": {"enable": False}, - }, - "out_aggregator": {"enable": False}, - "add_input": False, - }, - "pretrained_weights": "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/models/custom_semantic_segmentation/litehrnetxv3_imagenet1k_rsc.pth", - } - default_decode_head_configuration: ClassVar[dict[str, Any]] = { - "norm_cfg": {"type": "BN", "requires_grad": True}, - "in_channels": [18, 60, 80, 160, 320], - "in_index": [0, 1, 2, 3, 4], - "input_transform": "multiple_select", - "channels": 60, - "kernel_size": 1, - "num_convs": 1, - "concat_input": False, - "dropout_ratio": -1, - "enable_aggregator": True, - "aggregator_min_channels": 60, - "aggregator_merge_norm": "None", - "aggregator_use_concat": False, - "align_corners": False, - } - - @property - def ignore_scope(self) -> dict[str, str | dict[str, list[str]]]: - """The ignored scope of the LiteHRNetX model.""" - ignored_scope_names = [ - "__module.model.backbone.stage0.0.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage0.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage0.0.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage0.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage0.0/aten::add_/Add_1", - "__module.model.backbone.stage0.1.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage0.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage0.1.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage0.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage0.1/aten::add_/Add_1", - "__module.model.backbone.stage1.0.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage1.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage1.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage1.0.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage1.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage1.0/aten::add_/Add_1", - "__module.model.backbone.stage1.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", - 
"__module.model.backbone.stage1.0/aten::add_/Add_2", - "__module.model.backbone.stage1.0/aten::add_/Add_5", - "__module.model.backbone.stage1.1.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage1.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage1.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage1.1.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage1.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage1.1/aten::add_/Add_1", - "__module.model.backbone.stage1.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage1.1/aten::add_/Add_2", - "__module.model.backbone.stage1.1/aten::add_/Add_5", - "__module.model.backbone.stage1.2.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage1.2.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage1.2.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage1.2.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage1.2.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage1.2/aten::add_/Add_1", - "__module.model.backbone.stage1.2.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage1.2/aten::add_/Add_2", - "__module.model.backbone.stage1.2/aten::add_/Add_5", - "__module.model.backbone.stage1.3.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage1.3.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage1.3.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage1.3.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage1.3.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage1.3/aten::add_/Add_1", - "__module.model.backbone.stage1.3.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage1.3/aten::add_/Add_2", - "__module.model.backbone.stage1.3/aten::add_/Add_5", - "__module.model.backbone.stage2.0.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage2.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage2.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage2.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_3", - "__module.model.backbone.stage2.0.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage2.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage2.0/aten::add_/Add_1", - "__module.model.backbone.stage2.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage2.0/aten::add_/Add_2", - "__module.model.backbone.stage2.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_3", - "__module.model.backbone.stage2.0/aten::add_/Add_3", - "__module.model.backbone.stage2.0/aten::add_/Add_6", - "__module.model.backbone.stage2.0/aten::add_/Add_7", - "__module.model.backbone.stage2.0/aten::add_/Add_11", - "__module.model.backbone.stage2.1.layers.0.cross_resolution_weighting/aten::mul/Multiply", - 
"__module.model.backbone.stage2.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage2.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage2.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_3", - "__module.model.backbone.stage2.1.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage2.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage2.1/aten::add_/Add_1", - "__module.model.backbone.stage2.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage2.1/aten::add_/Add_2", - "__module.model.backbone.stage2.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_3", - "__module.model.backbone.stage2.1/aten::add_/Add_3", - "__module.model.backbone.stage2.1/aten::add_/Add_6", - "__module.model.backbone.stage2.1/aten::add_/Add_7", - "__module.model.backbone.stage2.1/aten::add_/Add_11", - "__module.model.backbone.stage2.2.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage2.2.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage2.2.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage2.2.layers.0.cross_resolution_weighting/aten::mul/Multiply_3", - "__module.model.backbone.stage2.2.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage2.2.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage2.2/aten::add_/Add_1", - "__module.model.backbone.stage2.2.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage2.2/aten::add_/Add_2", - "__module.model.backbone.stage2.2.layers.1.cross_resolution_weighting/aten::mul/Multiply_3", - "__module.model.backbone.stage2.2/aten::add_/Add_3", - "__module.model.backbone.stage2.2/aten::add_/Add_6", - "__module.model.backbone.stage2.2/aten::add_/Add_7", - "__module.model.backbone.stage2.2/aten::add_/Add_11", - "__module.model.backbone.stage2.3.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage2.3.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage2.3.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage2.3.layers.0.cross_resolution_weighting/aten::mul/Multiply_3", - "__module.model.backbone.stage2.3.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage2.3.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage2.3/aten::add_/Add_1", - "__module.model.backbone.stage2.3.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage2.3/aten::add_/Add_2", - "__module.model.backbone.stage2.3.layers.1.cross_resolution_weighting/aten::mul/Multiply_3", - "__module.model.backbone.stage2.3/aten::add_/Add_3", - "__module.model.backbone.stage2.3/aten::add_/Add_6", - "__module.model.backbone.stage2.3/aten::add_/Add_7", - "__module.model.backbone.stage2.3/aten::add_/Add_11", - "__module.model.backbone.stage3.0.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage3.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage3.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage3.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_3", - 
"__module.model.backbone.stage3.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_4", - "__module.model.backbone.stage3.0.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage3.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage3.0/aten::add_/Add_1", - "__module.model.backbone.stage3.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage3.0/aten::add_/Add_2", - "__module.model.backbone.stage3.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_3", - "__module.model.backbone.stage3.0/aten::add_/Add_3", - "__module.model.backbone.stage3.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_4", - "__module.model.backbone.stage3.0/aten::add_/Add_4", - "__module.model.backbone.stage3.0/aten::add_/Add_7", - "__module.model.backbone.stage3.0/aten::add_/Add_8", - "__module.model.backbone.stage3.0/aten::add_/Add_9", - "__module.model.backbone.stage3.0/aten::add_/Add_13", - "__module.model.backbone.stage3.0/aten::add_/Add_14", - "__module.model.backbone.stage3.0/aten::add_/Add_19", - "__module.model.backbone.stage3.1.layers.0.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage3.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage3.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage3.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_3", - "__module.model.backbone.stage3.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_4", - "__module.model.backbone.stage3.1.layers.1.cross_resolution_weighting/aten::mul/Multiply", - "__module.model.backbone.stage3.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", - "__module.model.backbone.stage3.1/aten::add_/Add_1", - "__module.model.backbone.stage3.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", - "__module.model.backbone.stage3.1/aten::add_/Add_2", - "__module.model.backbone.stage3.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_3", - "__module.model.backbone.stage3.1/aten::add_/Add_3", - "__module.model.backbone.stage3.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_4", - "__module.model.backbone.stage3.1/aten::add_/Add_4", - "__module.model.backbone.stage3.1/aten::add_/Add_7", - "__module.model.backbone.stage3.1/aten::add_/Add_8", - "__module.model.backbone.stage3.1/aten::add_/Add_9", - "__module.model.backbone.stage3.1/aten::add_/Add_13", - "__module.model.backbone.stage3.1/aten::add_/Add_14", - "__module.model.backbone.stage3.1/aten::add_/Add_19", - "__module.model.backbone.stage0.0/aten::add_/Add", - "__module.model.backbone.stage0.1/aten::add_/Add", - "__module.model.backbone.stage1.0/aten::add_/Add", - "__module.model.backbone.stage1.1/aten::add_/Add", - "__module.model.backbone.stage1.2/aten::add_/Add", - "__module.model.backbone.stage1.3/aten::add_/Add", - "__module.model.backbone.stage2.0/aten::add_/Add", - "__module.model.backbone.stage2.1/aten::add_/Add", - "__module.model.backbone.stage2.2/aten::add_/Add", - "__module.model.backbone.stage2.3/aten::add_/Add", - "__module.model.backbone.stage3.0/aten::add_/Add", - "__module.model.backbone.stage3.1/aten::add_/Add", - ] - - return { - "ignored_scope": { - "patterns": ["__module.model.decode_head.aggregator/*"], - "names": ignored_scope_names, - }, - "preset": "performance", - } - - -LITEHRNET_VARIANTS = { - "LiteHRNet18": LiteHRNet18, - "LiteHRNetS": LiteHRNetS, - "LiteHRNetX": LiteHRNetX, -} - - -class 
OTXLiteHRNet(TorchVisionCompatibleModel): +class LiteHRNet(OTXSegmentationModel): """LiteHRNet Model.""" - def __init__( - self, - label_info: LabelInfoTypes, - input_size: tuple[int, int] = (512, 512), - optimizer: OptimizerCallable = DefaultOptimizerCallable, - scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, - metric: MetricCallable = SegmCallable, # type: ignore[assignment] - torch_compile: bool = False, - backbone_configuration: dict[str, Any] | None = None, - decode_head_configuration: dict[str, Any] | None = None, - criterion_configuration: list[dict[str, Any]] | None = None, - export_image_configuration: dict[str, Any] | None = None, - name_base_model: str = "semantic_segmentation_model", - ): - super().__init__( - label_info=label_info, - input_size=input_size, - optimizer=optimizer, - scheduler=scheduler, - metric=metric, - torch_compile=torch_compile, - backbone_configuration=backbone_configuration, - decode_head_configuration=decode_head_configuration, - criterion_configuration=criterion_configuration, - export_image_configuration=export_image_configuration, - name_base_model=name_base_model, - ) + AVAILABLE_MODEL_VERSIONS: ClassVar[list[str]] = [ + "lite_hrnet_s", + "lite_hrnet_18", + "lite_hrnet_x", + ] - def _create_model(self) -> nn.Module: - litehrnet_model_class = LITEHRNET_VARIANTS[self.name_base_model] - # merge configurations with defaults overriding them - backbone_configuration = litehrnet_model_class.default_backbone_configuration | self.backbone_configuration - decode_head_configuration = ( - litehrnet_model_class.default_decode_head_configuration | self.decode_head_configuration - ) - # initialize backbones - backbone = LiteHRNet(**backbone_configuration) - decode_head = FCNHead(num_classes=self.num_classes, **decode_head_configuration) + def _build_model(self) -> nn.Module: + if self.model_version not in self.AVAILABLE_MODEL_VERSIONS: + msg = f"Model version {self.model_version} is not supported." 
+ raise ValueError(msg) - return litehrnet_model_class( + backbone = LiteHRNetBackbone(self.model_version) + decode_head = FCNHead(self.model_version, num_classes=self.num_classes) + criterion = CrossEntropyLossWithIgnore(ignore_index=self.label_info.ignore_index) # type: ignore[attr-defined] + return BaseSegmModel( backbone=backbone, decode_head=decode_head, - criterion_configuration=self.criterion_configuration, + criterion=criterion, ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict: @@ -578,8 +52,7 @@ def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model @property def _optimization_config(self) -> dict[str, Any]: """PTQ config for LiteHRNet.""" - # TODO(Kirill): check PTQ without adding the whole backbone to ignored_scope - ignored_scope = self.model.ignore_scope + ignored_scope = self.ignore_scope optim_config = { "advanced_parameters": { "activations_range_estimator_params": { @@ -611,50 +84,250 @@ def _exporter(self) -> OTXModelExporter: output_names=None, ) + @property + def ignore_scope(self) -> dict[str, Any]: + """Get the ignored scope for LiteHRNet.""" + if self.model_version == "large": + return { + "ignored_scope": { + "patterns": ["__module.model.decode_head.aggregator/*"], + "names": [ + "__module.model.backbone.stage0.0.layers.0.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage0.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage0.0.layers.1.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage0.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage0.0/aten::add_/Add_1", + "__module.model.backbone.stage0.1.layers.0.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage0.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage0.1.layers.1.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage0.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage0.1/aten::add_/Add_1", + "__module.model.backbone.stage1.0.layers.0.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage1.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage1.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage1.0.layers.1.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage1.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage1.0/aten::add_/Add_1", + "__module.model.backbone.stage1.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage1.0/aten::add_/Add_2", + "__module.model.backbone.stage1.0/aten::add_/Add_5", + "__module.model.backbone.stage1.1.layers.0.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage1.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage1.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage1.1.layers.1.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage1.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage1.1/aten::add_/Add_1", + "__module.model.backbone.stage1.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", + 
"__module.model.backbone.stage1.1/aten::add_/Add_2", + "__module.model.backbone.stage1.1/aten::add_/Add_5", + "__module.model.backbone.stage1.2.layers.0.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage1.2.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage1.2.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage1.2.layers.1.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage1.2.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage1.2/aten::add_/Add_1", + "__module.model.backbone.stage1.2.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage1.2/aten::add_/Add_2", + "__module.model.backbone.stage1.2/aten::add_/Add_5", + "__module.model.backbone.stage1.3.layers.0.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage1.3.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage1.3.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage1.3.layers.1.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage1.3.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage1.3/aten::add_/Add_1", + "__module.model.backbone.stage1.3.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage1.3/aten::add_/Add_2", + "__module.model.backbone.stage1.3/aten::add_/Add_5", + "__module.model.backbone.stage2.0.layers.0.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage2.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage2.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage2.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_3", + "__module.model.backbone.stage2.0.layers.1.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage2.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage2.0/aten::add_/Add_1", + "__module.model.backbone.stage2.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage2.0/aten::add_/Add_2", + "__module.model.backbone.stage2.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_3", + "__module.model.backbone.stage2.0/aten::add_/Add_3", + "__module.model.backbone.stage2.0/aten::add_/Add_6", + "__module.model.backbone.stage2.0/aten::add_/Add_7", + "__module.model.backbone.stage2.0/aten::add_/Add_11", + "__module.model.backbone.stage2.1.layers.0.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage2.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage2.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage2.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_3", + "__module.model.backbone.stage2.1.layers.1.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage2.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage2.1/aten::add_/Add_1", + "__module.model.backbone.stage2.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage2.1/aten::add_/Add_2", + "__module.model.backbone.stage2.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_3", + 
"__module.model.backbone.stage2.1/aten::add_/Add_3", + "__module.model.backbone.stage2.1/aten::add_/Add_6", + "__module.model.backbone.stage2.1/aten::add_/Add_7", + "__module.model.backbone.stage2.1/aten::add_/Add_11", + "__module.model.decode_head.aggregator/aten::add/Add", + "__module.model.decode_head.aggregator/aten::add/Add_1", + "__module.model.decode_head.aggregator/aten::add/Add_2", + "__module.model.backbone.stage2.1/aten::add_/Add", + ], + }, + "preset": "performance", + } + + if self.model_version == "medium": + return { + "ignored_scope": { + "patterns": ["__module.model.backbone/*"], + "names": [ + "__module.model.backbone.stage0.0.layers.0.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage0.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage0.0.layers.1.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage0.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage0.0/aten::add_/Add_1", + "__module.model.backbone.stage0.1.layers.0.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage0.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage0.1.layers.1.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage0.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage0.1/aten::add_/Add_1", + "__module.model.backbone.stage1.0.layers.0.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage1.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage1.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage1.0.layers.1.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage1.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage1.0/aten::add_/Add_1", + "__module.model.backbone.stage1.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage1.0/aten::add_/Add_2", + "__module.model.backbone.stage1.0/aten::add_/Add_5", + "__module.model.backbone.stage1.1.layers.0.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage1.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage1.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage1.1.layers.1.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage1.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage1.1/aten::add_/Add_1", + "__module.model.backbone.stage1.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage1.1/aten::add_/Add_2", + "__module.model.backbone.stage1.1/aten::add_/Add_5", + "__module.model.backbone.stage1.2.layers.0.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage1.2.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage1.2.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage1.2.layers.1.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage1.2.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage1.2/aten::add_/Add_1", + 
"__module.model.backbone.stage1.2.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage1.2/aten::add_/Add_2", + "__module.model.backbone.stage1.2/aten::add_/Add_5", + "__module.model.backbone.stage1.3.layers.0.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage1.3.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage1.3.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage1.3.layers.1.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage1.3.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage1.3/aten::add_/Add_1", + "__module.model.backbone.stage1.3.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage1.3/aten::add_/Add_2", + "__module.model.backbone.stage1.3/aten::add_/Add_5", + "__module.model.backbone.stage2.0.layers.0.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage2.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage2.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage2.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_3", + "__module.model.backbone.stage2.0.layers.1.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage2.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage2.0/aten::add_/Add_1", + "__module.model.backbone.stage2.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage2.0/aten::add_/Add_2", + "__module.model.backbone.stage2.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_3", + "__module.model.backbone.stage2.0/aten::add_/Add_3", + "__module.model.backbone.stage2.0/aten::add_/Add_6", + "__module.model.backbone.stage2.0/aten::add_/Add_7", + "__module.model.backbone.stage2.0/aten::add_/Add_11", + "__module.model.backbone.stage2.1.layers.0.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage2.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage2.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage2.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_3", + "__module.model.backbone.stage2.1.layers.1.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage2.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage2.1/aten::add_/Add_1", + "__module.model.backbone.stage2.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage2.1/aten::add_/Add_2", + "__module.model.backbone.stage2.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_3", + "__module.model.backbone.stage2.1/aten::add_/Add_3", + "__module.model.backbone.stage2.1/aten::add_/Add_6", + "__module.model.backbone.stage2.1/aten::add_/Add_7", + "__module.model.backbone.stage2.1/aten::add_/Add_11", + "__module.model.decode_head.aggregator/aten::add/Add", + "__module.model.decode_head.aggregator/aten::add/Add_1", + "__module.model.decode_head.aggregator/aten::add/Add_2", + "__module.model.backbone.stage2.1/aten::add_/Add", + ], + }, + "preset": "mixed", + } + + if self.model_version == "small": + return { + "ignored_scope": { + "names": [ + "__module.model.backbone.stage0.0.layers.0.cross_resolution_weighting/aten::mul/Multiply", + 
"__module.model.backbone.stage0.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage0.0.layers.1.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage0.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage0.0/aten::add_/Add_1", + "__module.model.backbone.stage0.1.layers.0.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage0.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage0.1.layers.1.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage0.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage0.1/aten::add_/Add_1", + "__module.model.backbone.stage0.2.layers.0.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage0.2.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage0.2.layers.1.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage0.2.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage0.2/aten::add_/Add_1", + "__module.model.backbone.stage0.3.layers.0.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage0.3.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage0.3.layers.1.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage0.3.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage0.3/aten::add_/Add_1", + "__module.model.backbone.stage1.0.layers.0.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage1.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage1.0.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage1.0.layers.1.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage1.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage1.0/aten::add_/Add_1", + "__module.model.backbone.stage1.0.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage1.0/aten::add_/Add_2", + "__module.model.backbone.stage1.0/aten::add_/Add_5", + "__module.model.backbone.stage1.1.layers.0.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage1.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage1.1.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage1.1.layers.1.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage1.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage1.1/aten::add_/Add_1", + "__module.model.backbone.stage1.1.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage1.1/aten::add_/Add_2", + "__module.model.backbone.stage1.1/aten::add_/Add_5", + "__module.model.backbone.stage1.2.layers.0.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage1.2.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage1.2.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage1.2.layers.1.cross_resolution_weighting/aten::mul/Multiply", + 
"__module.model.backbone.stage1.2.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage1.2/aten::add_/Add_1", + "__module.model.backbone.stage1.2.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage1.2/aten::add_/Add_2", + "__module.model.backbone.stage1.2/aten::add_/Add_5", + "__module.model.backbone.stage1.3.layers.0.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage1.3.layers.0.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage1.3.layers.0.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage1.3.layers.1.cross_resolution_weighting/aten::mul/Multiply", + "__module.model.backbone.stage1.3.layers.1.cross_resolution_weighting/aten::mul/Multiply_1", + "__module.model.backbone.stage1.3/aten::add_/Add_1", + "__module.model.backbone.stage1.3.layers.1.cross_resolution_weighting/aten::mul/Multiply_2", + "__module.model.backbone.stage1.3/aten::add_/Add_2", + "__module.model.backbone.stage1.3/aten::add_/Add_5", + "__module.model.decode_head.aggregator/aten::add/Add", + "__module.model.decode_head.aggregator/aten::add/Add_1", + ], + }, + "preset": "mixed", + } -class LiteHRNetSemiSL(OTXLiteHRNet): - """LiteHRNetSemiSL Model.""" - - def _create_model(self) -> nn.Module: - litehrnet_model_class = LITEHRNET_VARIANTS[self.name_base_model] - # merge configurations with defaults overriding them - backbone_configuration = litehrnet_model_class.default_backbone_configuration | self.backbone_configuration - decode_head_configuration = ( - litehrnet_model_class.default_decode_head_configuration | self.decode_head_configuration - ) - # initialize backbones - backbone = LiteHRNet(**backbone_configuration) - decode_head = FCNHead(num_classes=self.num_classes, **decode_head_configuration) - - base_model = litehrnet_model_class( - backbone=backbone, - decode_head=decode_head, - criterion_configuration=self.criterion_configuration, - ) - - return MeanTeacher(base_model, unsup_weight=0.7, drop_unrel_pixels_percent=20, semisl_start_epoch=2) - - def _customize_inputs(self, entity: SegBatchDataEntity) -> dict[str, Any]: - if not isinstance(entity, dict): - if self.training: - msg = "unlabeled inputs should be provided for semi-sl training" - raise RuntimeError(msg) - return super()._customize_inputs(entity) - - entity["labeled"].masks = torch.stack(entity["labeled"].masks).long() - w_u_images = entity["weak_transforms"].images - s_u_images = entity["strong_transforms"].images - unlabeled_img_metas = entity["weak_transforms"].imgs_info - labeled_inputs = entity["labeled"] - - return { - "inputs": labeled_inputs.images, - "unlabeled_weak_images": w_u_images, - "unlabeled_strong_images": s_u_images, - "global_step": self.trainer.global_step, - "steps_per_epoch": self.trainer.num_training_batches, - "img_metas": labeled_inputs.imgs_info, - "unlabeled_img_metas": unlabeled_img_metas, - "masks": labeled_inputs.masks, - "mode": "loss", - } + return {} diff --git a/src/otx/algo/segmentation/losses/__init__.py b/src/otx/algo/segmentation/losses/__init__.py index 7efddcd87ba..801a5863ac3 100644 --- a/src/otx/algo/segmentation/losses/__init__.py +++ b/src/otx/algo/segmentation/losses/__init__.py @@ -2,35 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 # """Custom Losses for OTX segmentation model.""" -from __future__ import annotations - -from typing import TYPE_CHECKING from .cross_entropy_loss_with_ignore import CrossEntropyLossWithIgnore -if TYPE_CHECKING: - from 
torch import nn - -__all__ = ["CrossEntropyLossWithIgnore", "create_criterion"] - - -def create_criterion(losses: list | str, params: dict | None = None) -> nn.Module: - """Create loss function by name.""" - if isinstance(losses, list): - creterions: list = [] - for loss in losses: - loss_type = loss["type"] - params = loss.get("params", {}) - creterions.append(create_criterion(loss_type, params)) - return creterions - - if isinstance(losses, str): - params = {} if params is None else params - if losses == "CrossEntropyLoss": - return CrossEntropyLossWithIgnore(**params) - - msg = f"Unknown loss type: {losses}" - raise ValueError(msg) - - msg = "losses should be a dict or a string" - raise ValueError(msg) +__all__ = ["CrossEntropyLossWithIgnore"] diff --git a/src/otx/algo/segmentation/modules/__init__.py b/src/otx/algo/segmentation/modules/__init__.py index ba01f4c4d51..cc0f41e2bd1 100644 --- a/src/otx/algo/segmentation/modules/__init__.py +++ b/src/otx/algo/segmentation/modules/__init__.py @@ -5,13 +5,10 @@ from .aggregators import IterativeAggregator -from .blocks import AsymmetricPositionAttentionModule, LocalAttentionModule from .utils import channel_shuffle, normalize, resize __all__ = [ - "AsymmetricPositionAttentionModule", "IterativeAggregator", - "LocalAttentionModule", "channel_shuffle", "resize", "normalize", diff --git a/src/otx/algo/segmentation/modules/blocks.py b/src/otx/algo/segmentation/modules/blocks.py index 240924ab476..9128bb30cf6 100644 --- a/src/otx/algo/segmentation/modules/blocks.py +++ b/src/otx/algo/segmentation/modules/blocks.py @@ -5,119 +5,9 @@ from __future__ import annotations -from typing import Callable, ClassVar +from typing import Callable import torch -import torch.nn.functional as f -from torch import nn -from torch.nn import AdaptiveAvgPool2d, AdaptiveMaxPool2d - -from otx.algo.modules import Conv2dModule - - -class PSPModule(nn.Module): - """PSP module. - - Reference: https://github.com/MendelXu/ANN. - """ - - methods: ClassVar[dict[str, AdaptiveMaxPool2d | AdaptiveAvgPool2d]] = { - "max": AdaptiveMaxPool2d, - "avg": AdaptiveAvgPool2d, - } - - def __init__(self, sizes: tuple = (1, 3, 6, 8), method: str = "max"): - super().__init__() - - pool_block = self.methods[method] - - self.stages = nn.ModuleList([pool_block(output_size=(size, size)) for size in sizes]) - - def forward(self, feats: torch.Tensor) -> torch.Tensor: - """Forward.""" - batch_size, c, _, _ = feats.size() - - priors = [stage(feats).view(batch_size, c, -1) for stage in self.stages] - - return torch.cat(priors, -1) - - -class AsymmetricPositionAttentionModule(nn.Module): - """AsymmetricPositionAttentionModule. - - Reference: https://github.com/MendelXu/ANN. 
- """ - - def __init__( - self, - in_channels: int, - key_channels: int, - value_channels: int | None = None, - psp_size: tuple | None = None, - norm_cfg: dict | None = None, - ): - super().__init__() - - self.in_channels = in_channels - self.key_channels = key_channels - self.value_channels = value_channels if value_channels is not None else in_channels - if norm_cfg is None: - norm_cfg = {"type": "BN"} - if psp_size is None: - psp_size = (1, 3, 6, 8) - self.norm_cfg = norm_cfg - self.query_key = Conv2dModule( - in_channels=self.in_channels, - out_channels=self.key_channels, - kernel_size=1, - stride=1, - padding=0, - norm_cfg=self.norm_cfg, - activation_callable=nn.ReLU, - ) - self.key_psp = PSPModule(psp_size, method="max") - - self.value = Conv2dModule( - in_channels=self.in_channels, - out_channels=self.value_channels, - kernel_size=1, - stride=1, - padding=0, - norm_cfg=self.norm_cfg, - activation_callable=nn.ReLU, - ) - self.value_psp = PSPModule(psp_size, method="max") - - self.out_conv = Conv2dModule( - in_channels=self.value_channels, - out_channels=self.in_channels, - kernel_size=1, - stride=1, - padding=0, - norm_cfg=self.norm_cfg, - activation_callable=None, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward.""" - batch_size, _, _ = x.size(0), x.size(2), x.size(3) - - query_key = self.query_key(x) - - key = self.key_psp(query_key) - value = self.value_psp(self.value(x)).permute(0, 2, 1) - query = query_key.view(batch_size, self.key_channels, -1).permute(0, 2, 1) - - similarity_scores = torch.matmul(query, key) - similarity_scores = (self.key_channels**-0.5) * similarity_scores - similarity_scores = f.softmax(similarity_scores, dim=-1) - - y = torch.matmul(similarity_scores, value) - y = y.permute(0, 2, 1).contiguous() - y = y.view(batch_size, self.value_channels, *x.size()[2:]) - y = self.out_conv(y) - - return x + y class OnnxLpNormalization(torch.autograd.Function): @@ -147,62 +37,3 @@ def symbolic( """Symbolic onnxLpNormalization.""" del eps # These args are not used. return g.op("LpNormalization", x, axis_i=int(axis), p_i=int(p)) - - -class LocalAttentionModule(nn.Module): - """LocalAttentionModule. - - Reference: https://github.com/lxtGH/GALD-DGCNet. 
- """ - - def __init__(self, num_channels: int, norm_cfg: dict | None = None): - if norm_cfg is None: - norm_cfg = {"type": "BN"} - super().__init__() - - self.num_channels = num_channels - self.norm_cfg = norm_cfg - - self.dwconv1 = Conv2dModule( - in_channels=self.num_channels, - out_channels=self.num_channels, - kernel_size=3, - stride=2, - padding=1, - groups=self.num_channels, - norm_cfg=self.norm_cfg, - activation_callable=nn.ReLU, - ) - self.dwconv2 = Conv2dModule( - in_channels=self.num_channels, - out_channels=self.num_channels, - kernel_size=3, - stride=2, - padding=1, - groups=self.num_channels, - norm_cfg=self.norm_cfg, - activation_callable=nn.ReLU, - ) - self.dwconv3 = Conv2dModule( - in_channels=self.num_channels, - out_channels=self.num_channels, - kernel_size=3, - stride=2, - padding=1, - groups=self.num_channels, - norm_cfg=self.norm_cfg, - activation_callable=nn.ReLU, - ) - self.sigmoid_spatial = nn.Sigmoid() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward.""" - _, _, h, w = x.size() - - y = self.dwconv1(x) - y = self.dwconv2(y) - y = self.dwconv3(y) - y = f.interpolate(y, size=(h, w), mode="bilinear", align_corners=True) - mask = self.sigmoid_spatial(y) - - return x + x * mask diff --git a/src/otx/algo/segmentation/segmentors/base_model.py b/src/otx/algo/segmentation/segmentors/base_model.py index c979be625ae..c66c49f84f5 100644 --- a/src/otx/algo/segmentation/segmentors/base_model.py +++ b/src/otx/algo/segmentation/segmentors/base_model.py @@ -5,44 +5,36 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import torch.nn.functional as f from torch import Tensor, nn -from otx.algo.segmentation.losses import create_criterion - if TYPE_CHECKING: from otx.core.data.entity.base import ImageInfo class BaseSegmModel(nn.Module): - """Base Segmentation Model.""" + """Base Segmentation Model. + + Args: + backbone (nn.Module): The backbone of the segmentation model. + decode_head (nn.Module): The decode head of the segmentation model. + criterion (nn.Module, optional): The criterion of the model. Defaults to None. + If None, use CrossEntropyLoss with ignore_index=255. + """ def __init__( self, backbone: nn.Module, decode_head: nn.Module, - criterion_configuration: list[dict[str, str | Any]] | None = None, + criterion: nn.Module | None = None, ) -> None: - """Initializes a segmentation model. - - Args: - backbone (nn.Module): The backbone of the segmentation model. - decode_head (nn.Module): The decode head of the segmentation model. - criterion_configuration (Dict[str, str | Any]): The criterion of the model. - If None, use CrossEntropyLoss with ignore_index=255. - - Returns: - None - """ super().__init__() - if criterion_configuration is None: - criterion_configuration = [{"type": "CrossEntropyLoss", "params": {"ignore_index": 255}}] + self.criterion = nn.CrossEntropyLoss(ignore_index=255) if criterion is None else criterion self.backbone = backbone self.decode_head = decode_head - self.criterions = create_criterion(criterion_configuration) def forward( self, @@ -66,8 +58,7 @@ def forward( - If mode is "predict", returns the predicted outputs. - Otherwise, returns the model outputs after interpolation. 
""" - enc_feats = self.backbone(inputs) - outputs = self.decode_head(enc_feats) + outputs = self.extract_features(inputs) outputs = f.interpolate(outputs, size=inputs.size()[2:], mode="bilinear", align_corners=True) if mode == "tensor": @@ -87,6 +78,11 @@ def forward( return outputs + def extract_features(self, inputs: Tensor) -> Tensor: + """Extract features from the backbone and head.""" + enc_feats = self.backbone(inputs) + return self.decode_head(enc_feats) + def calculate_loss( self, model_features: Tensor, @@ -112,22 +108,21 @@ def calculate_loss( # class incremental training valid_label_mask = self.get_valid_label_mask(img_metas) output_losses = {} - for criterion in self.criterions: - valid_label_mask_cfg = {} - if criterion.name == "loss_ce_ignore": - valid_label_mask_cfg["valid_label_mask"] = valid_label_mask - if criterion.name not in output_losses: - output_losses[criterion.name] = criterion( - outputs, - masks, - **valid_label_mask_cfg, - ) - else: - output_losses[criterion.name] += criterion( - outputs, - masks, - **valid_label_mask_cfg, - ) + valid_label_mask_cfg = {} + if self.criterion.name == "loss_ce_ignore": + valid_label_mask_cfg["valid_label_mask"] = valid_label_mask + if self.criterion.name not in output_losses: + output_losses[self.criterion.name] = self.criterion( + outputs, + masks, + **valid_label_mask_cfg, + ) + else: + output_losses[self.criterion.name] += self.criterion( + outputs, + masks, + **valid_label_mask_cfg, + ) return output_losses def get_valid_label_mask(self, img_metas: list[ImageInfo]) -> list[Tensor]: diff --git a/src/otx/algo/segmentation/segnext.py b/src/otx/algo/segmentation/segnext.py index b3e28beda14..7d3445a959e 100644 --- a/src/otx/algo/segmentation/segnext.py +++ b/src/otx/algo/segmentation/segnext.py @@ -2,164 +2,43 @@ # SPDX-License-Identifier: Apache-2.0 # """SegNext model implementations.""" - from __future__ import annotations from typing import TYPE_CHECKING, Any, ClassVar -import torch -from torch import nn - from otx.algo.segmentation.backbones import MSCAN from otx.algo.segmentation.heads import LightHamHead -from otx.algo.segmentation.segmentors import BaseSegmModel, MeanTeacher +from otx.algo.segmentation.losses import CrossEntropyLossWithIgnore +from otx.algo.segmentation.segmentors import BaseSegmModel from otx.algo.utils.support_otx_v1 import OTXv1Helper -from otx.core.data.entity.segmentation import SegBatchDataEntity -from otx.core.metrics.dice import SegmCallable -from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable -from otx.core.model.segmentation import TorchVisionCompatibleModel +from otx.core.model.segmentation import OTXSegmentationModel if TYPE_CHECKING: - from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable - - from otx.core.metrics import MetricCallable - from otx.core.schedulers import LRSchedulerListCallable - from otx.core.types.label import LabelInfoTypes - - -class SegNextB(BaseSegmModel): - """SegNextB Model.""" - - default_backbone_configuration: ClassVar[dict[str, Any]] = { - "activation_callable": nn.GELU, - "attention_kernel_paddings": [2, [0, 3], [0, 5], [0, 10]], - "attention_kernel_sizes": [5, [1, 7], [1, 11], [1, 21]], - "depths": [3, 3, 12, 3], - "drop_path_rate": 0.1, - "drop_rate": 0.0, - "embed_dims": [64, 128, 320, 512], - "mlp_ratios": [8, 8, 4, 4], - "norm_cfg": {"requires_grad": True, "type": "BN"}, - "pretrained_weights": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_b_20230227-3ab7d230.pth", - } - 
default_decode_head_configuration: ClassVar[dict[str, Any]] = { - "ham_kwargs": {"md_r": 16, "md_s": 1, "eval_steps": 7, "train_steps": 6}, - "in_channels": [128, 320, 512], - "in_index": [1, 2, 3], - "norm_cfg": {"num_groups": 32, "requires_grad": True, "type": "GN"}, - "align_corners": False, - "channels": 512, - "dropout_ratio": 0.1, - "ham_channels": 512, - } - - -class SegNextS(BaseSegmModel): - """SegNextS Model.""" + from torch import nn - default_backbone_configuration: ClassVar[dict[str, Any]] = { - "activation_callable": nn.GELU, - "attention_kernel_paddings": [2, [0, 3], [0, 5], [0, 10]], - "attention_kernel_sizes": [5, [1, 7], [1, 11], [1, 21]], - "depths": [2, 2, 4, 2], - "drop_path_rate": 0.1, - "drop_rate": 0.0, - "embed_dims": [64, 128, 320, 512], - "mlp_ratios": [8, 8, 4, 4], - "norm_cfg": {"requires_grad": True, "type": "BN"}, - "pretrained_weights": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_s_20230227-f33ccdf2.pth", - } - default_decode_head_configuration: ClassVar[dict[str, Any]] = { - "norm_cfg": {"num_groups": 32, "requires_grad": True, "type": "GN"}, - "ham_kwargs": {"md_r": 16, "md_s": 1, "eval_steps": 7, "rand_init": True, "train_steps": 6}, - "in_channels": [128, 320, 512], - "in_index": [1, 2, 3], - "align_corners": False, - "channels": 256, - "dropout_ratio": 0.1, - "ham_channels": 256, - } - -class SegNextT(BaseSegmModel): - """SegNextT Model.""" - - default_backbone_configuration: ClassVar[dict[str, Any]] = { - "activation_callable": nn.GELU, - "attention_kernel_paddings": [2, [0, 3], [0, 5], [0, 10]], - "attention_kernel_sizes": [5, [1, 7], [1, 11], [1, 21]], - "depths": [3, 3, 5, 2], - "drop_path_rate": 0.1, - "drop_rate": 0.0, - "embed_dims": [32, 64, 160, 256], - "mlp_ratios": [8, 8, 4, 4], - "norm_cfg": {"requires_grad": True, "type": "BN"}, - "pretrained_weights": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_t_20230227-119e8c9f.pth", - } - default_decode_head_configuration: ClassVar[dict[str, Any]] = { - "ham_kwargs": {"md_r": 16, "md_s": 1, "eval_steps": 7, "rand_init": True, "train_steps": 6}, - "norm_cfg": {"num_groups": 32, "requires_grad": True, "type": "GN"}, - "in_channels": [64, 160, 256], - "in_index": [1, 2, 3], - "align_corners": False, - "channels": 256, - "dropout_ratio": 0.1, - "ham_channels": 256, - } - - -SEGNEXT_VARIANTS = { - "SegNextB": SegNextB, - "SegNextS": SegNextS, - "SegNextT": SegNextT, -} - - -class OTXSegNext(TorchVisionCompatibleModel): +class SegNext(OTXSegmentationModel): """SegNext Model.""" - def __init__( - self, - label_info: LabelInfoTypes, - input_size: tuple[int, int] = (512, 512), - optimizer: OptimizerCallable = DefaultOptimizerCallable, - scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, - metric: MetricCallable = SegmCallable, # type: ignore[assignment] - torch_compile: bool = False, - backbone_configuration: dict[str, Any] | None = None, - decode_head_configuration: dict[str, Any] | None = None, - criterion_configuration: list[dict[str, Any]] | None = None, - export_image_configuration: dict[str, Any] | None = None, - name_base_model: str = "semantic_segmentation_model", - ): - super().__init__( - label_info=label_info, - input_size=input_size, - optimizer=optimizer, - scheduler=scheduler, - metric=metric, - torch_compile=torch_compile, - backbone_configuration=backbone_configuration, - decode_head_configuration=decode_head_configuration, - criterion_configuration=criterion_configuration, - 
export_image_configuration=export_image_configuration, - name_base_model=name_base_model, - ) + AVAILABLE_MODEL_VERSIONS: ClassVar[list[str]] = [ + "segnext_tiny", + "segnext_small", + "segnext_base", + ] - def _create_model(self) -> nn.Module: - segnext_model_class = SEGNEXT_VARIANTS[self.name_base_model] - # merge configurations with defaults overriding them - backbone_configuration = segnext_model_class.default_backbone_configuration | self.backbone_configuration - decode_head_configuration = ( - segnext_model_class.default_decode_head_configuration | self.decode_head_configuration - ) + def _build_model(self) -> nn.Module: # initialize backbones - backbone = MSCAN(**backbone_configuration) - decode_head = LightHamHead(num_classes=self.num_classes, **decode_head_configuration) - return segnext_model_class( + if self.model_version not in self.AVAILABLE_MODEL_VERSIONS: + msg = f"Model version {self.model_version} is not supported." + raise ValueError(msg) + + backbone = MSCAN(version=self.model_version) + decode_head = LightHamHead(version=self.model_version, num_classes=self.num_classes) + criterion = CrossEntropyLossWithIgnore(ignore_index=self.label_info.ignore_index) # type: ignore[attr-defined] + return BaseSegmModel( backbone=backbone, decode_head=decode_head, - criterion_configuration=self.criterion_configuration, + criterion=criterion, ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict: @@ -181,49 +60,3 @@ def _optimization_config(self) -> dict[str, Any]: ], }, } - - -class SemiSLSegNext(OTXSegNext): - """SegNext Model.""" - - def _customize_inputs(self, entity: SegBatchDataEntity) -> dict[str, Any]: - if not isinstance(entity, dict): - if self.training: - msg = "unlabeled inputs should be provided for semi-sl training" - raise RuntimeError(msg) - return super()._customize_inputs(entity) - - entity["labeled"].masks = torch.stack(entity["labeled"].masks).long() - w_u_images = entity["weak_transforms"].images - s_u_images = entity["strong_transforms"].images - unlabeled_img_metas = entity["weak_transforms"].imgs_info - labeled_inputs = entity["labeled"] - - return { - "inputs": labeled_inputs.images, - "unlabeled_weak_images": w_u_images, - "unlabeled_strong_images": s_u_images, - "global_step": self.trainer.global_step, - "steps_per_epoch": self.trainer.num_training_batches, - "img_metas": labeled_inputs.imgs_info, - "unlabeled_img_metas": unlabeled_img_metas, - "masks": labeled_inputs.masks, - "mode": "loss", - } - - def _create_model(self) -> nn.Module: - segnext_model_class = SEGNEXT_VARIANTS[self.name_base_model] - # merge configurations with defaults overriding them - backbone_configuration = segnext_model_class.default_backbone_configuration | self.backbone_configuration - decode_head_configuration = ( - segnext_model_class.default_decode_head_configuration | self.decode_head_configuration - ) - # initialize backbones - backbone = MSCAN(**backbone_configuration) - decode_head = LightHamHead(num_classes=self.num_classes, **decode_head_configuration) - base_model = segnext_model_class( - backbone=backbone, - decode_head=decode_head, - criterion_configuration=self.criterion_configuration, - ) - return MeanTeacher(base_model, unsup_weight=0.7, drop_unrel_pixels_percent=20, semisl_start_epoch=2) diff --git a/src/otx/core/model/base.py b/src/otx/core/model/base.py index bd42c668a52..5a1cdcac64c 100644 --- a/src/otx/core/model/base.py +++ b/src/otx/core/model/base.py @@ -49,6 +49,7 @@ from otx.core.types.export import OTXExportFormatType, 
TaskLevelExportParameters from otx.core.types.label import LabelInfo, LabelInfoTypes, NullLabelInfo from otx.core.types.precision import OTXPrecisionType +from otx.core.types.task import OTXTrainType from otx.core.utils.build import get_default_num_async_infer_requests from otx.core.utils.miscellaneous import ensure_callable from otx.core.utils.utils import is_ckpt_for_finetuning, is_ckpt_from_otx_v1, remove_state_dict_prefix @@ -113,10 +114,12 @@ def __init__( metric: MetricCallable = NullMetricCallable, torch_compile: bool = False, tile_config: TileConfig = TileConfig(enable_tiler=False), + train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED, ) -> None: super().__init__() self._label_info = self._dispatch_label_info(label_info) + self.train_type = train_type self._check_input_size(input_size) self.input_size = input_size self.classification_layers: dict[str, dict[str, Any]] = {} diff --git a/src/otx/core/model/classification.py b/src/otx/core/model/classification.py index f48f026b585..bed73e975e4 100644 --- a/src/otx/core/model/classification.py +++ b/src/otx/core/model/classification.py @@ -53,8 +53,6 @@ def __init__( torch_compile: bool = False, train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED, ) -> None: - self.train_type = train_type - super().__init__( label_info=label_info, input_size=input_size, @@ -62,6 +60,7 @@ def __init__( scheduler=scheduler, metric=metric, torch_compile=torch_compile, + train_type=train_type, ) self.input_size: tuple[int, int] diff --git a/src/otx/core/model/segmentation.py b/src/otx/core/model/segmentation.py index cddc7beb2fc..330deb89bec 100644 --- a/src/otx/core/model/segmentation.py +++ b/src/otx/core/model/segmentation.py @@ -6,12 +6,15 @@ from __future__ import annotations import json +from abc import abstractmethod from collections.abc import Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar, Literal import torch +from torch import nn from torchvision import tv_tensors +from otx.algo.segmentation.segmentors import MeanTeacher from otx.core.data.entity.base import ImageInfo, OTXBatchLossEntity from otx.core.data.entity.segmentation import SegBatchDataEntity, SegBatchPredEntity from otx.core.exporter.base import OTXModelExporter @@ -23,6 +26,7 @@ from otx.core.types.export import OTXExportFormatType, TaskLevelExportParameters from otx.core.types.label import LabelInfo, LabelInfoTypes, SegLabelInfo from otx.core.types.precision import OTXPrecisionType +from otx.core.types.task import OTXTrainType if TYPE_CHECKING: from pathlib import Path @@ -37,14 +41,22 @@ class OTXSegmentationModel(OTXModel[SegBatchDataEntity, SegBatchPredEntity]): """Base class for the semantic segmentation models used in OTX.""" + mean: ClassVar[tuple[float, float, float]] = (0.485, 0.456, 0.406) + scale: ClassVar[tuple[float, float, float]] = (0.229, 0.224, 0.225) + def __init__( self, label_info: LabelInfoTypes, - input_size: tuple[int, int], + input_size: tuple[int, int] = (512, 512), optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = SegmCallable, # type: ignore[assignment] torch_compile: bool = False, + train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED, + model_version: str | None = None, + unsupervised_weight: float = 0.7, + semisl_start_epoch: int = 2, + 
drop_unreliable_pixels_percent: int = 20, ): """Base semantic segmentation model. @@ -59,7 +71,21 @@ def __init__( Defaults to SegmCallable. torch_compile (bool, optional): Whether to compile the model using TorchScript. Defaults to False. + train_type (Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED], optional): + The training type of the model. Defaults to OTXTrainType.SUPERVISED. + model_version (str | None, optional): The version of the model. Defaults to None. + unsupervised_weight (float, optional): The weight of the unsupervised loss. + Only for semi-supervised learning. Defaults to 0.7. + semisl_start_epoch (int, optional): The epoch at which the semi-supervised learning starts. + Only for semi-supervised learning. Defaults to 2. + drop_unreliable_pixels_percent (int, optional): The percentage of unreliable pixels to drop. + Only for semi-supervised learning. Defaults to 20. """ + self.model_version = model_version + self.unsupervised_weight = unsupervised_weight + self.semisl_start_epoch = semisl_start_epoch + self.drop_unreliable_pixels_percent = drop_unreliable_pixels_percent + super().__init__( label_info=label_info, input_size=input_size, @@ -67,9 +93,75 @@ def __init__( scheduler=scheduler, metric=metric, torch_compile=torch_compile, + train_type=train_type, ) self.input_size: tuple[int, int] + def _create_model(self) -> nn.Module: + base_model = self._build_model() + if self.train_type == OTXTrainType.SEMI_SUPERVISED: + return MeanTeacher( + base_model, + unsup_weight=self.unsupervised_weight, + drop_unrel_pixels_percent=self.drop_unreliable_pixels_percent, + semisl_start_epoch=self.semisl_start_epoch, + ) + + return base_model + + @abstractmethod + def _build_model(self) -> nn.Module: + """Building base nn.Module model. 
+ + Returns: + nn.Module: base nn.Module model for supervised training + """ + + def _customize_inputs(self, entity: SegBatchDataEntity) -> dict[str, Any]: + mode = "loss" if self.training else "predict" + + if self.train_type == OTXTrainType.SEMI_SUPERVISED and mode == "loss": + if not isinstance(entity, dict): + msg = "unlabeled inputs should be provided for semi-sl training" + raise RuntimeError(msg) + + return { + "inputs": entity["labeled"].images, + "unlabeled_weak_images": entity["weak_transforms"].images, + "unlabeled_strong_images": entity["strong_transforms"].images, + "global_step": self.trainer.global_step, + "steps_per_epoch": self.trainer.num_training_batches, + "img_metas": entity["labeled"].imgs_info, + "unlabeled_img_metas": entity["weak_transforms"].imgs_info, + "masks": torch.stack(entity["labeled"].masks).long(), + "mode": mode, + } + + masks = torch.stack(entity.masks).long() if mode == "loss" else None + return {"inputs": entity.images, "img_metas": entity.imgs_info, "masks": masks, "mode": mode} + + def _customize_outputs( + self, + outputs: Any, # noqa: ANN401 + inputs: SegBatchDataEntity, + ) -> SegBatchPredEntity | OTXBatchLossEntity: + if self.training: + if not isinstance(outputs, dict): + raise TypeError(outputs) + + losses = OTXBatchLossEntity() + for k, v in outputs.items(): + losses[k] = v + return losses + + return SegBatchPredEntity( + batch_size=len(outputs), + images=inputs.images, + imgs_info=inputs.imgs_info, + scores=[], + masks=outputs, + ) + @property def _export_parameters(self) -> TaskLevelExportParameters: """Defines parameters required to export a particular model implementation.""" @@ -81,6 +173,26 @@ def _export_parameters(self) -> TaskLevelExportParameters: blur_strength=-1, ) + @property + def _exporter(self) -> OTXModelExporter: + """Creates OTXModelExporter object that can export the model.""" + if self.input_size is None: + msg = f"Image size attribute is not set for {self.__class__}" + raise ValueError(msg) + + return OTXNativeModelExporter( + task_level_export_parameters=self._export_parameters, + input_size=(1, 3, *self.input_size), + mean=self.mean, + std=self.scale, + resize_mode="standard", + pad_value=0, + swap_rgb=False, + via_onnx=False, + onnx_export_configuration=None, + output_names=None, + ) + def _convert_pred_entity_to_compute_metric( self, preds: SegBatchPredEntity, @@ -147,120 +259,12 @@ def export( Returns: Path: path to the exported model. 
""" - if hasattr(self.model, "student_model"): - # use only teacher model - # TODO(Kirill): make this based on the training type + if self.train_type == OTXTrainType.SEMI_SUPERVISED: + # use only teacher model for deployment self.model = self.model.teacher_model return super().export(output_dir, base_name, export_format, precision, to_exportable_code) -class TorchVisionCompatibleModel(OTXSegmentationModel): - """Segmentation model compatible with torchvision data pipeline.""" - - def __init__( - self, - label_info: LabelInfoTypes, - input_size: tuple[int, int], - optimizer: OptimizerCallable = DefaultOptimizerCallable, - scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, - metric: MetricCallable = SegmCallable, # type: ignore[assignment] - torch_compile: bool = False, - backbone_configuration: dict[str, Any] | None = None, - decode_head_configuration: dict[str, Any] | None = None, - criterion_configuration: list[dict[str, Any]] | None = None, - export_image_configuration: dict[str, Any] | None = None, - name_base_model: str = "semantic_segmentation_model", - ): - """Torchvision compatible model. - - Args: - label_info (LabelInfoTypes): The label information for the segmentation model. - input_size (tuple[int, int]): Model input size in the order of height and width. - optimizer (OptimizerCallable, optional): The optimizer callable for the model. - Defaults to DefaultOptimizerCallable. - scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): - The learning rate scheduler callable for the model. Defaults to DefaultSchedulerCallable. - metric (MetricCallable, optional): The metric callable for the model. - Defaults to SegmCallable. - torch_compile (bool, optional): Whether to compile the model using Torch. Defaults to False. - backbone_configuration (dict[str, Any] | None, optional): - The configuration for the backbone of the model. Defaults to None. - decode_head_configuration (dict[str, Any] | None, optional): - The configuration for the decode head of the model. Defaults to None. - criterion_configuration (list[dict[str, Any]] | None, optional): - The configuration for the criterion of the model. Defaults to None. - export_image_configuration (dict[str, Any] | None, optional): - The configuration for the export of the model like mean and scale. Defaults to None. - name_base_model (str, optional): The name of the base model used for trainig. - Defaults to "semantic_segmentation_model". 
- """ - self.backbone_configuration = backbone_configuration if backbone_configuration is not None else {} - self.decode_head_configuration = decode_head_configuration if decode_head_configuration is not None else {} - export_image_configuration = export_image_configuration if export_image_configuration is not None else {} - self.criterion_configuration = criterion_configuration - self.mean = export_image_configuration.get("mean", [123.675, 116.28, 103.53]) - self.scale = export_image_configuration.get("std", [58.395, 57.12, 57.375]) - self.name_base_model = name_base_model - - super().__init__( - label_info=label_info, - input_size=input_size, - optimizer=optimizer, - scheduler=scheduler, - metric=metric, - torch_compile=torch_compile, - ) - - def _customize_inputs(self, entity: SegBatchDataEntity) -> dict[str, Any]: - mode = "loss" if self.training else "predict" - - masks = torch.stack(entity.masks).long() if mode == "loss" else None - - return {"inputs": entity.images, "img_metas": entity.imgs_info, "masks": masks, "mode": mode} - - def _customize_outputs( - self, - outputs: Any, # noqa: ANN401 - inputs: SegBatchDataEntity, - ) -> SegBatchPredEntity | OTXBatchLossEntity: - if self.training: - if not isinstance(outputs, dict): - raise TypeError(outputs) - - losses = OTXBatchLossEntity() - for k, v in outputs.items(): - losses[k] = v - return losses - - return SegBatchPredEntity( - batch_size=len(outputs), - images=inputs.images, - imgs_info=inputs.imgs_info, - scores=[], - masks=outputs, - ) - - @property - def _exporter(self) -> OTXModelExporter: - """Creates OTXModelExporter object that can export the model.""" - if self.input_size is None: - msg = f"Input size attribute is not set for {self.__class__}" - raise ValueError(msg) - - return OTXNativeModelExporter( - task_level_export_parameters=self._export_parameters, - input_size=(1, 3, *self.input_size), - mean=self.mean, - std=self.scale, - resize_mode="standard", - pad_value=0, - swap_rgb=False, - via_onnx=False, - onnx_export_configuration=None, - output_names=None, - ) - - class OVSegmentationModel(OVModel[SegBatchDataEntity, SegBatchPredEntity]): """Semantic segmentation model compatible for OpenVINO IR inference. 
diff --git a/src/otx/recipe/semantic_segmentation/dino_v2.yaml b/src/otx/recipe/semantic_segmentation/dino_v2.yaml index 984e858860d..713b8e92624 100644 --- a/src/otx/recipe/semantic_segmentation/dino_v2.yaml +++ b/src/otx/recipe/semantic_segmentation/dino_v2.yaml @@ -1,12 +1,11 @@ model: - class_path: otx.algo.segmentation.dino_v2_seg.OTXDinoV2Seg + class_path: otx.algo.segmentation.dino_v2_seg.DinoV2Seg init_args: label_info: 2 - - criterion_configuration: - - type: CrossEntropyLoss - params: - ignore_index: 255 + model_version: dinov2_vits14 + input_size: + - 560 + - 560 optimizer: class_path: torch.optim.AdamW diff --git a/src/otx/recipe/semantic_segmentation/litehrnet_18.yaml b/src/otx/recipe/semantic_segmentation/litehrnet_18.yaml index 9156bd38a11..e7a20d7e369 100644 --- a/src/otx/recipe/semantic_segmentation/litehrnet_18.yaml +++ b/src/otx/recipe/semantic_segmentation/litehrnet_18.yaml @@ -1,13 +1,8 @@ model: - class_path: otx.algo.segmentation.litehrnet.OTXLiteHRNet + class_path: otx.algo.segmentation.litehrnet.LiteHRNet init_args: label_info: 2 - name_base_model: LiteHRNet18 - - criterion_configuration: - - type: CrossEntropyLoss - params: - ignore_index: 255 + model_version: lite_hrnet_18 optimizer: class_path: torch.optim.Adam diff --git a/src/otx/recipe/semantic_segmentation/litehrnet_s.yaml b/src/otx/recipe/semantic_segmentation/litehrnet_s.yaml index a62938480bd..d353ffdfc4c 100644 --- a/src/otx/recipe/semantic_segmentation/litehrnet_s.yaml +++ b/src/otx/recipe/semantic_segmentation/litehrnet_s.yaml @@ -1,13 +1,8 @@ model: - class_path: otx.algo.segmentation.litehrnet.OTXLiteHRNet + class_path: otx.algo.segmentation.litehrnet.LiteHRNet init_args: label_info: 2 - name_base_model: LiteHRNetS - - criterion_configuration: - - type: CrossEntropyLoss - params: - ignore_index: 255 + model_version: lite_hrnet_s optimizer: class_path: torch.optim.Adam diff --git a/src/otx/recipe/semantic_segmentation/litehrnet_x.yaml b/src/otx/recipe/semantic_segmentation/litehrnet_x.yaml index 100edf1d8b2..85bb55d55ca 100644 --- a/src/otx/recipe/semantic_segmentation/litehrnet_x.yaml +++ b/src/otx/recipe/semantic_segmentation/litehrnet_x.yaml @@ -1,13 +1,8 @@ model: - class_path: otx.algo.segmentation.litehrnet.OTXLiteHRNet + class_path: otx.algo.segmentation.litehrnet.LiteHRNet init_args: label_info: 2 - name_base_model: LiteHRNetX - - criterion_configuration: - - type: CrossEntropyLoss - params: - ignore_index: 255 + model_version: lite_hrnet_x optimizer: class_path: torch.optim.Adam diff --git a/src/otx/recipe/semantic_segmentation/segnext_b.yaml b/src/otx/recipe/semantic_segmentation/segnext_b.yaml index 62cced98c53..49626e58d6c 100644 --- a/src/otx/recipe/semantic_segmentation/segnext_b.yaml +++ b/src/otx/recipe/semantic_segmentation/segnext_b.yaml @@ -1,13 +1,8 @@ model: - class_path: otx.algo.segmentation.segnext.OTXSegNext + class_path: otx.algo.segmentation.segnext.SegNext init_args: label_info: 2 - name_base_model: SegNextB - - criterion_configuration: - - type: CrossEntropyLoss - params: - ignore_index: 255 + model_version: segnext_base optimizer: class_path: torch.optim.AdamW diff --git a/src/otx/recipe/semantic_segmentation/segnext_s.yaml b/src/otx/recipe/semantic_segmentation/segnext_s.yaml index 8686002822c..e8eab1d22e7 100644 --- a/src/otx/recipe/semantic_segmentation/segnext_s.yaml +++ b/src/otx/recipe/semantic_segmentation/segnext_s.yaml @@ -1,13 +1,8 @@ model: - class_path: otx.algo.segmentation.segnext.OTXSegNext + class_path: otx.algo.segmentation.segnext.SegNext init_args: 
label_info: 2 - name_base_model: SegNextS - - criterion_configuration: - - type: CrossEntropyLoss - params: - ignore_index: 255 + model_version: segnext_small optimizer: class_path: torch.optim.AdamW diff --git a/src/otx/recipe/semantic_segmentation/segnext_t.yaml b/src/otx/recipe/semantic_segmentation/segnext_t.yaml index 621c827f334..755c26ee49c 100644 --- a/src/otx/recipe/semantic_segmentation/segnext_t.yaml +++ b/src/otx/recipe/semantic_segmentation/segnext_t.yaml @@ -1,13 +1,8 @@ model: - class_path: otx.algo.segmentation.segnext.OTXSegNext + class_path: otx.algo.segmentation.segnext.SegNext init_args: label_info: 2 - name_base_model: SegNextT - - criterion_configuration: - - type: CrossEntropyLoss - params: - ignore_index: 255 + model_version: segnext_tiny optimizer: class_path: torch.optim.AdamW diff --git a/src/otx/recipe/semantic_segmentation/semisl/dino_v2_semisl.yaml b/src/otx/recipe/semantic_segmentation/semisl/dino_v2_semisl.yaml index 8902a549cde..7dc5ece097c 100644 --- a/src/otx/recipe/semantic_segmentation/semisl/dino_v2_semisl.yaml +++ b/src/otx/recipe/semantic_segmentation/semisl/dino_v2_semisl.yaml @@ -1,12 +1,12 @@ model: - class_path: otx.algo.segmentation.dino_v2_seg.DinoV2SegSemiSL + class_path: otx.algo.segmentation.dino_v2_seg.DinoV2Seg init_args: label_info: 2 - - criterion_configuration: - - type: CrossEntropyLoss - params: - ignore_index: 255 + model_version: dinov2_vits14 + train_type: SEMI_SUPERVISED + input_size: + - 560 + - 560 optimizer: class_path: torch.optim.AdamW @@ -17,13 +17,6 @@ model: - 0.999 weight_decay: 0.0001 - export_image_configuration: - image_size: - - 1 - - 3 - - 560 - - 560 - scheduler: class_path: torch.optim.lr_scheduler.PolynomialLR init_args: diff --git a/src/otx/recipe/semantic_segmentation/semisl/litehrnet_18_semisl.yaml b/src/otx/recipe/semantic_segmentation/semisl/litehrnet_18_semisl.yaml index 8a6c02a5fd4..a98f1ab47a2 100644 --- a/src/otx/recipe/semantic_segmentation/semisl/litehrnet_18_semisl.yaml +++ b/src/otx/recipe/semantic_segmentation/semisl/litehrnet_18_semisl.yaml @@ -1,13 +1,9 @@ model: - class_path: otx.algo.segmentation.litehrnet.LiteHRNetSemiSL + class_path: otx.algo.segmentation.litehrnet.LiteHRNet init_args: label_info: 2 - name_base_model: LiteHRNet18 - - criterion_configuration: - - type: CrossEntropyLoss - params: - ignore_index: 255 + model_version: lite_hrnet_18 + train_type: SEMI_SUPERVISED optimizer: class_path: torch.optim.Adam diff --git a/src/otx/recipe/semantic_segmentation/semisl/litehrnet_s_semisl.yaml b/src/otx/recipe/semantic_segmentation/semisl/litehrnet_s_semisl.yaml index c34bf5436a4..c0cd0de594f 100644 --- a/src/otx/recipe/semantic_segmentation/semisl/litehrnet_s_semisl.yaml +++ b/src/otx/recipe/semantic_segmentation/semisl/litehrnet_s_semisl.yaml @@ -1,13 +1,9 @@ model: - class_path: otx.algo.segmentation.litehrnet.LiteHRNetSemiSL + class_path: otx.algo.segmentation.litehrnet.LiteHRNet init_args: label_info: 2 - name_base_model: LiteHRNetS - - criterion_configuration: - - type: CrossEntropyLoss - params: - ignore_index: 255 + model_version: lite_hrnet_s + train_type: SEMI_SUPERVISED optimizer: class_path: torch.optim.Adam diff --git a/src/otx/recipe/semantic_segmentation/semisl/litehrnet_x_semisl.yaml b/src/otx/recipe/semantic_segmentation/semisl/litehrnet_x_semisl.yaml index a5f6e8f0606..ab757f65887 100644 --- a/src/otx/recipe/semantic_segmentation/semisl/litehrnet_x_semisl.yaml +++ b/src/otx/recipe/semantic_segmentation/semisl/litehrnet_x_semisl.yaml @@ -1,13 +1,9 @@ model: - class_path: 
otx.algo.segmentation.litehrnet.LiteHRNetSemiSL + class_path: otx.algo.segmentation.litehrnet.LiteHRNet init_args: label_info: 2 - name_base_model: LiteHRNetX - - criterion_configuration: - - type: CrossEntropyLoss - params: - ignore_index: 255 + model_version: lite_hrnet_x + train_type: SEMI_SUPERVISED optimizer: class_path: torch.optim.Adam diff --git a/src/otx/recipe/semantic_segmentation/semisl/segnext_b_semisl.yaml b/src/otx/recipe/semantic_segmentation/semisl/segnext_b_semisl.yaml index d8557a58465..395d0fb5c5e 100644 --- a/src/otx/recipe/semantic_segmentation/semisl/segnext_b_semisl.yaml +++ b/src/otx/recipe/semantic_segmentation/semisl/segnext_b_semisl.yaml @@ -1,13 +1,9 @@ model: - class_path: otx.algo.segmentation.segnext.OTXSegNext + class_path: otx.algo.segmentation.segnext.SegNext init_args: label_info: 2 - name_base_model: SegNextB - - criterion_configuration: - - type: CrossEntropyLoss - params: - ignore_index: 255 + model_version: segnext_base + train_type: SEMI_SUPERVISED optimizer: class_path: torch.optim.AdamW @@ -36,10 +32,7 @@ engine: callback_monitor: val/Dice data: ../../_base_/data/semisl/semantic_segmentation_semisl.yaml - overrides: - model: - class_path: otx.algo.segmentation.segnext.SemiSLSegNext callbacks: - class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup init_args: diff --git a/src/otx/recipe/semantic_segmentation/semisl/segnext_s_semisl.yaml b/src/otx/recipe/semantic_segmentation/semisl/segnext_s_semisl.yaml index 231587d4835..8748572ed6b 100644 --- a/src/otx/recipe/semantic_segmentation/semisl/segnext_s_semisl.yaml +++ b/src/otx/recipe/semantic_segmentation/semisl/segnext_s_semisl.yaml @@ -1,13 +1,9 @@ model: - class_path: otx.algo.segmentation.segnext.SemiSLSegNext + class_path: otx.algo.segmentation.segnext.SegNext init_args: label_info: 2 - name_base_model: SegNextS - - criterion_configuration: - - type: CrossEntropyLoss - params: - ignore_index: 255 + model_version: segnext_small + train_type: SEMI_SUPERVISED optimizer: class_path: torch.optim.AdamW diff --git a/src/otx/recipe/semantic_segmentation/semisl/segnext_t_semisl.yaml b/src/otx/recipe/semantic_segmentation/semisl/segnext_t_semisl.yaml index 6152665bb0a..b6b884b2759 100644 --- a/src/otx/recipe/semantic_segmentation/semisl/segnext_t_semisl.yaml +++ b/src/otx/recipe/semantic_segmentation/semisl/segnext_t_semisl.yaml @@ -1,13 +1,9 @@ model: - class_path: otx.algo.segmentation.segnext.SemiSLSegNext + class_path: otx.algo.segmentation.segnext.SegNext init_args: label_info: 2 - name_base_model: SegNextT - - criterion_configuration: - - type: CrossEntropyLoss - params: - ignore_index: 255 + model_version: segnext_tiny + train_type: SEMI_SUPERVISED optimizer: class_path: torch.optim.AdamW @@ -36,6 +32,7 @@ engine: callback_monitor: val/Dice data: ../../_base_/data/semisl/semantic_segmentation_semisl.yaml + overrides: callbacks: - class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup diff --git a/tests/unit/algo/classification/test_torchvision_model.py b/tests/unit/algo/classification/test_torchvision_model.py index d68374ed731..7c8dec308bf 100644 --- a/tests/unit/algo/classification/test_torchvision_model.py +++ b/tests/unit/algo/classification/test_torchvision_model.py @@ -8,7 +8,7 @@ from otx.core.data.entity.base import OTXBatchLossEntity, OTXBatchPredEntity from otx.core.data.entity.classification import MulticlassClsBatchPredEntity from otx.core.types.export import TaskLevelExportParameters -from otx.core.types.task import OTXTaskType +from 
otx.core.types.task import OTXTaskType, OTXTrainType @pytest.fixture() @@ -45,7 +45,12 @@ class TestOTXTVModel: def test_create_model(self, fxt_tv_model): assert isinstance(fxt_tv_model.model, TVClassificationModel) - semi_sl_model = OTXTVModel(backbone="mobilenet_v3_small", label_info=10, train_type="SEMI_SUPERVISED") + semi_sl_model = OTXTVModel( + backbone="mobilenet_v3_small", + label_info=10, + train_type=OTXTrainType.SEMI_SUPERVISED, + task=OTXTaskType.MULTI_CLASS_CLS, + ) assert isinstance(semi_sl_model.model.head, OTXSemiSLLinearClsHead) @pytest.mark.parametrize( diff --git a/tests/unit/algo/segmentation/backbones/test_dinov2.py b/tests/unit/algo/segmentation/backbones/test_dinov2.py index 8774767f61a..0e5f920d67e 100644 --- a/tests/unit/algo/segmentation/backbones/test_dinov2.py +++ b/tests/unit/algo/segmentation/backbones/test_dinov2.py @@ -30,23 +30,18 @@ def mock_torch_hub_load(self, mocker, mock_backbone): return mocker.patch("otx.algo.segmentation.backbones.dinov2.torch.hub.load", return_value=mock_backbone) def test_init(self, mock_backbone, mock_backbone_named_parameters): - dino = DinoVisionTransformer(name="dinov2_vits14_reg", freeze_backbone=True, out_index=[8, 9, 10, 11]) + dino = DinoVisionTransformer(name="dinov2_vits14", freeze_backbone=True, out_index=[8, 9, 10, 11]) assert dino.backbone == mock_backbone for parameter in mock_backbone_named_parameters.values(): assert parameter.requires_grad is False @pytest.fixture() - def mock_init_cfg(self) -> MagicMock: - return MagicMock() - - @pytest.fixture() - def dino_vit(self, mock_init_cfg) -> DinoVisionTransformer: + def dino_vit(self) -> DinoVisionTransformer: return DinoVisionTransformer( - name="dinov2_vits14_reg", + name="dinov2_vits14", freeze_backbone=True, out_index=[8, 9, 10, 11], - init_cfg=mock_init_cfg, ) def test_forward(self, dino_vit, mock_backbone): diff --git a/tests/unit/algo/segmentation/backbones/test_litehrnet.py b/tests/unit/algo/segmentation/backbones/test_litehrnet.py index eddac529ed0..32242fc3549 100644 --- a/tests/unit/algo/segmentation/backbones/test_litehrnet.py +++ b/tests/unit/algo/segmentation/backbones/test_litehrnet.py @@ -1,11 +1,8 @@ -from copy import deepcopy -from pathlib import Path from unittest.mock import MagicMock import pytest import torch -from otx.algo.segmentation.backbones import litehrnet as target_file -from otx.algo.segmentation.backbones.litehrnet import LiteHRNet, NeighbourSupport, SpatialWeightingV2, StemV2 +from otx.algo.segmentation.backbones.litehrnet import NeighbourSupport, NNLiteHRNet, SpatialWeightingV2, StemV2 class TestSpatialWeightingV2: @@ -52,9 +49,9 @@ def test_forward(self) -> None: assert outputs is not None -class TestLiteHRNet: +class TestNNLiteHRNet: @pytest.fixture() - def extra_cfg(self) -> dict: + def cfg(self) -> dict: return { "stem": { "stem_channels": 32, @@ -78,65 +75,23 @@ def extra_cfg(self) -> dict: (40, 80, 160, 320), ], }, - "out_modules": { - "conv": { - "enable": True, - "channels": 320, - }, - "position_att": { - "enable": True, - "key_channels": 128, - "value_channels": 320, - "psp_size": [1, 3, 6, 8], - }, - "local_att": { - "enable": False, - }, - }, } @pytest.fixture() - def backbone(self, extra_cfg) -> LiteHRNet: - return LiteHRNet(extra=extra_cfg) - - def test_init(self, extra_cfg) -> None: - extra = deepcopy(extra_cfg) + def backbone(self, cfg) -> NNLiteHRNet: + return NNLiteHRNet(**cfg) - extra["add_stem_features"] = True - model = LiteHRNet(extra=extra) - assert model is not None - - extra["stages_spec"]["module_type"] = 
("NAIVE", "NAIVE", "NAIVE") - extra["stages_spec"]["weighting_module_version"] = "v2" - model = LiteHRNet(extra=extra) - assert model is not None - - def test_init_weights(self, backbone) -> None: - backbone.init_weights() - - with pytest.raises(TypeError): - backbone.init_weights(0) - - def test_forward(self, extra_cfg, backbone) -> None: - backbone.train() - inputs = torch.randn((1, 3, 224, 224)) - outputs = backbone(inputs) - assert outputs is not None - - extra = deepcopy(extra_cfg) - extra["stages_spec"]["module_type"] = ("NAIVE", "NAIVE", "NAIVE") - extra["stages_spec"]["weighting_module_version"] = "v2" - model = LiteHRNet(extra=extra) - outputs = model(inputs) - assert outputs is not None + @pytest.fixture() + def mock_torch_load(self, mocker) -> MagicMock: + return mocker.patch("otx.algo.segmentation.backbones.litehrnet.torch.load") @pytest.fixture() def mock_load_from_http(self, mocker) -> MagicMock: - return mocker.patch.object(target_file, "load_from_http") + return mocker.patch("otx.algo.segmentation.backbones.litehrnet.load_from_http") @pytest.fixture() def mock_load_checkpoint_to_model(self, mocker) -> MagicMock: - return mocker.patch.object(target_file, "load_checkpoint_to_model") + return mocker.patch("otx.algo.segmentation.backbones.litehrnet.load_checkpoint_to_model") @pytest.fixture() def pretrained_weight(self, tmp_path) -> str: @@ -144,30 +99,36 @@ def pretrained_weight(self, tmp_path) -> str: weight.touch() return str(weight) - @pytest.fixture() - def mock_torch_load(self, mocker) -> MagicMock: - return mocker.patch("otx.algo.segmentation.backbones.mscan.torch.load") + def test_init(self, cfg) -> None: + model = NNLiteHRNet(**cfg) + assert model is not None + + def test_forward(self, cfg, backbone) -> None: + backbone.train() + inputs = torch.randn((1, 3, 224, 224)) + outputs = backbone(inputs) + assert outputs is not None + + def test_load_pretrained_weights_from_url( + self, + mock_load_from_http, + mock_load_checkpoint_to_model, + backbone, + ) -> None: + pretrained_weight = "www.fake.com/fake.pth" + backbone.load_pretrained_weights(pretrained=pretrained_weight) + mock_load_from_http.assert_called_once() + mock_load_checkpoint_to_model.assert_called_once() def test_load_pretrained_weights( self, - extra_cfg, + cfg, pretrained_weight, mock_torch_load, mock_load_checkpoint_to_model, ): - extra_cfg["add_stem_features"] = True - model = LiteHRNet(extra=extra_cfg) + model = NNLiteHRNet(**cfg) model.load_pretrained_weights(pretrained=pretrained_weight) mock_torch_load.assert_called_once_with(pretrained_weight, "cpu") mock_load_checkpoint_to_model.assert_called_once() - - def test_load_pretrained_weights_from_url(self, extra_cfg, mock_load_from_http, mock_load_checkpoint_to_model): - pretrained_weight = "www.fake.com/fake.pth" - extra_cfg["add_stem_features"] = True - model = LiteHRNet(extra=extra_cfg) - model.load_pretrained_weights(pretrained=pretrained_weight) - - cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints" - mock_load_from_http.assert_called_once_with(filename=pretrained_weight, map_location="cpu", model_dir=cache_dir) - mock_load_checkpoint_to_model.assert_called_once() diff --git a/tests/unit/algo/segmentation/backbones/test_mscan.py b/tests/unit/algo/segmentation/backbones/test_mscan.py index a991b9ba8c2..441e121ead1 100644 --- a/tests/unit/algo/segmentation/backbones/test_mscan.py +++ b/tests/unit/algo/segmentation/backbones/test_mscan.py @@ -4,7 +4,7 @@ import pytest import torch from otx.algo.segmentation.backbones import mscan as 
target_file -from otx.algo.segmentation.backbones.mscan import MSCAN, DropPath, drop_path +from otx.algo.segmentation.backbones.mscan import NNMSCAN, DropPath, drop_path @pytest.mark.parametrize("dim", [1, 2, 3, 4]) @@ -59,7 +59,7 @@ def test_forward(self): class TestMSCABlock: def test_init(self): num_stages = 4 - mscan = MSCAN(num_stages=num_stages) + mscan = NNMSCAN(num_stages=num_stages) for i in range(num_stages): assert hasattr(mscan, f"patch_embed{i + 1}") @@ -68,7 +68,7 @@ def test_init(self): def test_forward(self): num_stages = 4 - mscan = MSCAN(num_stages=num_stages) + mscan = NNMSCAN(num_stages=num_stages) x = torch.rand(8, 3, 3, 3) out = mscan.forward(x) @@ -93,14 +93,14 @@ def mock_torch_load(self, mocker) -> MagicMock: return mocker.patch("otx.algo.segmentation.backbones.mscan.torch.load") def test_load_pretrained_weights(self, pretrained_weight, mock_torch_load, mock_load_checkpoint_to_model): - MSCAN(pretrained_weights=pretrained_weight) + NNMSCAN(pretrained_weights=pretrained_weight) mock_torch_load.assert_called_once_with(pretrained_weight, "cpu") mock_load_checkpoint_to_model.assert_called_once() def test_load_pretrained_weights_from_url(self, mock_load_from_http, mock_load_checkpoint_to_model): pretrained_weight = "www.fake.com/fake.pth" - MSCAN(pretrained_weights=pretrained_weight) + NNMSCAN(pretrained_weights=pretrained_weight) cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints" mock_load_from_http.assert_called_once_with(filename=pretrained_weight, map_location="cpu", model_dir=cache_dir) diff --git a/tests/unit/algo/segmentation/heads/test_class_incremental.py b/tests/unit/algo/segmentation/heads/test_class_incremental.py index bc4e6fd285d..2ea0bf1713c 100644 --- a/tests/unit/algo/segmentation/heads/test_class_incremental.py +++ b/tests/unit/algo/segmentation/heads/test_class_incremental.py @@ -4,7 +4,7 @@ import torch -from otx.algo.segmentation.litehrnet import OTXLiteHRNet +from otx.algo.segmentation.litehrnet import LiteHRNet from otx.core.data.entity.base import ImageInfo @@ -14,7 +14,7 @@ class MockGT: class TestClassIncrementalMixin: def test_ignore_label(self) -> None: - hrnet = OTXLiteHRNet(3, name_base_model="LiteHRNet18") + hrnet = LiteHRNet(3, input_size=(128, 128), model_version="lite_hrnet_18") seg_logits = torch.randn(1, 3, 128, 128) # no annotations for class=3 diff --git a/tests/unit/algo/segmentation/heads/test_ham_head.py b/tests/unit/algo/segmentation/heads/test_ham_head.py index d3ed05b40da..1edcddf813d 100644 --- a/tests/unit/algo/segmentation/heads/test_ham_head.py +++ b/tests/unit/algo/segmentation/heads/test_ham_head.py @@ -4,10 +4,10 @@ import pytest import torch -from otx.algo.segmentation.heads.ham_head import LightHamHead +from otx.algo.segmentation.heads.ham_head import NNLightHamHead -class TestLightHamHead: +class TestNNLightHamHead: @pytest.fixture() def head_config(self) -> dict[str, Any]: return { @@ -23,7 +23,7 @@ def head_config(self) -> dict[str, Any]: } def test_init(self, head_config): - light_ham_head = LightHamHead(**head_config) + light_ham_head = NNLightHamHead(**head_config) assert light_ham_head.ham_channels == head_config["ham_channels"] @pytest.fixture() @@ -40,7 +40,7 @@ def fake_input(self, batch_size) -> list[torch.Tensor]: ] def test_forward(self, head_config, fake_input, batch_size): - light_ham_head = LightHamHead(**head_config) + light_ham_head = NNLightHamHead(**head_config) out = light_ham_head.forward(fake_input) assert out.size()[0] == batch_size assert out.size()[2] == 
fake_input[head_config["in_index"][0]].size()[2] diff --git a/tests/unit/algo/segmentation/modules/__init__.py b/tests/unit/algo/segmentation/modules/__init__.py deleted file mode 100644 index 39d45d9ed72..00000000000 --- a/tests/unit/algo/segmentation/modules/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# Copyright (C) 2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -"""Test of custom algo modules of OTX segmentation task.""" diff --git a/tests/unit/algo/segmentation/modules/test_blokcs.py b/tests/unit/algo/segmentation/modules/test_blokcs.py deleted file mode 100644 index 728d85169a0..00000000000 --- a/tests/unit/algo/segmentation/modules/test_blokcs.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from typing import Any - -import pytest -import torch -from otx.algo.segmentation.modules.blocks import AsymmetricPositionAttentionModule, LocalAttentionModule - - -class TestAsymmetricPositionAttentionModule: - @pytest.fixture() - def init_cfg(self) -> dict[str, Any]: - return { - "in_channels": 320, - "key_channels": 128, - "value_channels": 320, - "psp_size": [1, 3, 6, 8], - "norm_cfg": {"type": "BN"}, - } - - def test_init(self, init_cfg): - module = AsymmetricPositionAttentionModule(**init_cfg) - - assert module.in_channels == init_cfg["in_channels"] - assert module.key_channels == init_cfg["key_channels"] - assert module.value_channels == init_cfg["value_channels"] - assert module.norm_cfg == init_cfg["norm_cfg"] - - @pytest.fixture() - def fake_input(self) -> torch.Tensor: - return torch.rand(8, 320, 16, 16) - - def test_forward(self, init_cfg, fake_input): - module = AsymmetricPositionAttentionModule(**init_cfg) - out = module.forward(fake_input) - - assert out.size() == fake_input.size() - - -class TestLocalAttentionModule: - @pytest.fixture() - def init_cfg(self) -> dict[str, Any]: - return { - "num_channels": 320, - "norm_cfg": {"type": "BN"}, - } - - def test_init(self, init_cfg): - module = LocalAttentionModule(**init_cfg) - - assert module.num_channels == init_cfg["num_channels"] - assert module.norm_cfg == init_cfg["norm_cfg"] - - @pytest.fixture() - def fake_input(self) -> torch.Tensor: - return torch.rand(8, 320, 16, 16) - - def test_forward(self, init_cfg, fake_input): - module = LocalAttentionModule(**init_cfg) - - out = module.forward(fake_input) - assert out.size() == fake_input.size() diff --git a/tests/unit/algo/segmentation/segmentors/test_base_model.py b/tests/unit/algo/segmentation/segmentors/test_base_model.py new file mode 100644 index 00000000000..33a33af4dda --- /dev/null +++ b/tests/unit/algo/segmentation/segmentors/test_base_model.py @@ -0,0 +1,64 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +import pytest +import torch +from otx.algo.segmentation.segmentors.base_model import BaseSegmModel +from otx.core.data.entity.base import ImageInfo + + +class TestBaseSegmModel: + @pytest.fixture() + def model(self): + backbone = torch.nn.Sequential(torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)) + decode_head = torch.nn.Sequential(torch.nn.Conv2d(64, 2, kernel_size=1)) + decode_head.num_classes = 3 + return BaseSegmModel(backbone, decode_head) + + @pytest.fixture() + def inputs(self): + inputs = torch.randn(1, 3, 256, 256) + masks = torch.randint(0, 2, (1, 256, 256)) + return inputs, masks + + def test_forward_returns_tensor(self, model, inputs): + images = inputs[0] + output = model.forward(images) + assert isinstance(output, torch.Tensor) + + def test_forward_returns_loss(self, model, inputs): 
+ model.criterion.name = "CrossEntropyLoss" + images, masks = inputs + img_metas = [ImageInfo(img_shape=(256, 256), img_idx=0, ori_shape=(256, 256))] + output = model.forward(images, img_metas=img_metas, masks=masks, mode="loss") + assert isinstance(output, dict) + assert "CrossEntropyLoss" in output + + def test_forward_returns_prediction(self, model, inputs): + images = inputs[0] + output = model.forward(images, mode="predict") + assert isinstance(output, torch.Tensor) + assert output.shape == (1, 256, 256) + + def test_extract_features(self, model, inputs): + images = inputs[0] + features = model.extract_features(images) + assert isinstance(features, torch.Tensor) + assert features.shape == (1, 2, 256, 256) + + def test_calculate_loss(self, model, inputs): + model.criterion.name = "CrossEntropyLoss" + images, masks = inputs + img_metas = [ImageInfo(img_shape=(256, 256), img_idx=0, ori_shape=(256, 256))] + loss = model.calculate_loss(images, img_metas, masks, interpolate=False) + assert isinstance(loss, dict) + assert "CrossEntropyLoss" in loss + assert isinstance(loss["CrossEntropyLoss"], torch.Tensor) + + def test_get_valid_label_mask(self, model): + img_metas = [ImageInfo(img_shape=(256, 256), img_idx=0, ignored_labels=[0, 2], ori_shape=(256, 256))] + valid_label_mask = model.get_valid_label_mask(img_metas) + assert isinstance(valid_label_mask, list) + assert len(valid_label_mask) == 1 + assert isinstance(valid_label_mask[0], torch.Tensor) + assert valid_label_mask[0].tolist() == [0, 1, 0] diff --git a/tests/unit/algo/segmentation/segmentors/test_mean_teacher.py b/tests/unit/algo/segmentation/segmentors/test_mean_teacher.py index 47ce0c32eb1..f7be592b5fa 100644 --- a/tests/unit/algo/segmentation/segmentors/test_mean_teacher.py +++ b/tests/unit/algo/segmentation/segmentors/test_mean_teacher.py @@ -1,5 +1,10 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + import pytest import torch +from otx.algo.segmentation.losses import CrossEntropyLossWithIgnore from otx.algo.segmentation.segmentors import BaseSegmModel, MeanTeacher from otx.core.data.entity.base import ImageInfo from torch import nn @@ -10,9 +15,11 @@ class TestMeanTeacher: def model(self): decode_head = nn.Conv2d(3, 2, 1) decode_head.num_classes = 2 + loss = CrossEntropyLossWithIgnore(ignore_index=255) model = BaseSegmModel( backbone=nn.Sequential(nn.Conv2d(3, 5, 1), nn.ReLU(), nn.Conv2d(5, 3, 1)), decode_head=decode_head, + criterion=loss, ) return MeanTeacher(model) diff --git a/tests/unit/algo/segmentation/test_dino_v2_seg.py b/tests/unit/algo/segmentation/test_dino_v2_seg.py index 259b6f4816b..5353a43616a 100644 --- a/tests/unit/algo/segmentation/test_dino_v2_seg.py +++ b/tests/unit/algo/segmentation/test_dino_v2_seg.py @@ -1,19 +1,19 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - +# import pytest -from otx.algo.segmentation.dino_v2_seg import OTXDinoV2Seg +from otx.algo.segmentation.dino_v2_seg import DinoV2Seg from otx.core.exporter.base import OTXModelExporter class TestDinoV2Seg: @pytest.fixture(scope="class") - def fxt_dino_v2_seg(self) -> OTXDinoV2Seg: - return OTXDinoV2Seg(label_info=10) + def fxt_dino_v2_seg(self) -> DinoV2Seg: + return DinoV2Seg(label_info=10, model_version="dinov2_vits14", input_size=(560, 560)) def test_dino_v2_seg_init(self, fxt_dino_v2_seg): - assert isinstance(fxt_dino_v2_seg, OTXDinoV2Seg) + assert isinstance(fxt_dino_v2_seg, DinoV2Seg) assert fxt_dino_v2_seg.num_classes == 10 def test_exporter(self, fxt_dino_v2_seg): diff 
--git a/tests/unit/algo/segmentation/test_segnext.py b/tests/unit/algo/segmentation/test_segnext.py index 88f807aacfd..375ad9d0b61 100644 --- a/tests/unit/algo/segmentation/test_segnext.py +++ b/tests/unit/algo/segmentation/test_segnext.py @@ -3,17 +3,17 @@ import pytest -from otx.algo.segmentation.segnext import OTXSegNext +from otx.algo.segmentation.segnext import SegNext from otx.algo.utils.support_otx_v1 import OTXv1Helper class TestSegNext: @pytest.fixture() - def fxt_segnext(self) -> OTXSegNext: - return OTXSegNext(10, name_base_model="SegNextB") + def fxt_segnext(self) -> SegNext: + return SegNext(10, model_version="segnext_base", input_size=(512, 512)) def test_segnext_init(self, fxt_segnext): - assert isinstance(fxt_segnext, OTXSegNext) + assert isinstance(fxt_segnext, SegNext) assert fxt_segnext.num_classes == 10 def test_load_from_otx_v1_ckpt(self, fxt_segnext, mocker): diff --git a/tests/unit/algo/utils/test_segmentation.py b/tests/unit/algo/utils/test_segmentation.py index dc6893ccdd5..14b6557b914 100644 --- a/tests/unit/algo/utils/test_segmentation.py +++ b/tests/unit/algo/utils/test_segmentation.py @@ -5,27 +5,17 @@ import torch from otx.algo.segmentation.modules import ( - AsymmetricPositionAttentionModule, IterativeAggregator, - LocalAttentionModule, channel_shuffle, normalize, ) -from otx.algo.segmentation.modules.blocks import OnnxLpNormalization, PSPModule +from otx.algo.segmentation.modules.blocks import OnnxLpNormalization def test_channel_shuffle(): assert channel_shuffle(torch.randn([1, 24, 8, 8]), 4).shape == torch.Size([1, 24, 8, 8]) -def test_psp_module(): - assert PSPModule().forward(torch.randn([1, 24, 28, 8])).shape == torch.Size([1, 24, 110]) - - -def test_asymmetric_position_attention_module(): - assert AsymmetricPositionAttentionModule(24, 48)(torch.randn([1, 24, 8, 8])).shape == torch.Size([1, 24, 8, 8]) - - def test_onnx_lp_normalization(): assert OnnxLpNormalization().forward(None, torch.randn([1, 24, 8, 8])).shape == torch.Size([1, 24, 8, 8]) @@ -49,7 +39,3 @@ def test_iterative_aggregator(): assert len(out) == 2 assert out[0].shape == torch.Size([1, 2, 16, 16]) assert out[1].shape == torch.Size([1, 2, 8, 8]) - - -def test_local_attention_module(): - assert LocalAttentionModule(24).forward(torch.randn([2, 24, 8, 8])).shape == torch.Size([2, 24, 8, 8]) diff --git a/tests/unit/core/model/test_segmentation.py b/tests/unit/core/model/test_segmentation.py index 130aa3a96dd..d364c9ab273 100644 --- a/tests/unit/core/model/test_segmentation.py +++ b/tests/unit/core/model/test_segmentation.py @@ -11,42 +11,46 @@ from otx.core.data.entity.segmentation import SegBatchDataEntity, SegBatchPredEntity from otx.core.metrics.dice import SegmCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable -from otx.core.model.segmentation import OTXSegmentationModel, TorchVisionCompatibleModel +from otx.core.model.segmentation import OTXSegmentationModel from otx.core.types.label import SegLabelInfo -@pytest.fixture() -def label_info(): - return SegLabelInfo( - label_names=["Background", "label_0", "label_1"], - label_groups=[["Background", "label_0", "label_1"]], - ) - - -@pytest.fixture() -def optimizer(): - return DefaultOptimizerCallable - - -@pytest.fixture() -def scheduler(): - return DefaultSchedulerCallable +class TestOTXSegmentationModel: + @pytest.fixture() + def model(self, label_info, optimizer, scheduler, metric, torch_compile): + return OTXSegmentationModel(label_info, (512, 512), optimizer, scheduler, metric, torch_compile) + 
@pytest.fixture() + def batch_data_entity(self): + return SegBatchDataEntity( + batch_size=2, + images=torch.randn(2, 3, 224, 224), + imgs_info=[], + masks=[torch.randn(224, 224), torch.randn(224, 224)], + ) -@pytest.fixture() -def metric(): - return SegmCallable + @pytest.fixture() + def label_info(self): + return SegLabelInfo( + label_names=["Background", "label_0", "label_1"], + label_groups=[["Background", "label_0", "label_1"]], + ) + @pytest.fixture() + def optimizer(self): + return DefaultOptimizerCallable -@pytest.fixture() -def torch_compile(): - return False + @pytest.fixture() + def scheduler(self): + return DefaultSchedulerCallable + @pytest.fixture() + def metric(self): + return SegmCallable -class TestOTXSegmentationModel: @pytest.fixture() - def model(self, label_info, optimizer, scheduler, metric, torch_compile): - return OTXSegmentationModel(label_info, (512, 512), optimizer, scheduler, metric, torch_compile) + def torch_compile(self): + return False def test_export_parameters(self, model): params = model._export_parameters @@ -70,21 +74,6 @@ def test_dispatch_label_info(self, model, label_info, expected_label_info): result = model._dispatch_label_info(label_info) assert result == expected_label_info - -class TestTorchVisionCompatibleModel: - @pytest.fixture() - def model(self, label_info, optimizer, scheduler, metric, torch_compile) -> TorchVisionCompatibleModel: - return TorchVisionCompatibleModel(label_info, (512, 512), optimizer, scheduler, metric, torch_compile) - - @pytest.fixture() - def batch_data_entity(self): - return SegBatchDataEntity( - batch_size=2, - images=torch.randn(2, 3, 224, 224), - imgs_info=[], - masks=[torch.randn(224, 224), torch.randn(224, 224)], - ) - def test_init(self, model): assert model.num_classes == 3 @@ -109,7 +98,7 @@ def test_customize_outputs_predict(self, model, batch_data_entity): assert customized_outputs.images.shape == (2, 3, 224, 224) assert customized_outputs.imgs_info == [] - def test_dummy_input(self, model: TorchVisionCompatibleModel): + def test_dummy_input(self, model: OTXSegmentationModel): batch_size = 2 batch = model.get_dummy_input(batch_size) assert batch.batch_size == batch_size
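The consolidated unit tests now exercise OTXSegmentationModel directly instead of the removed torchvision-compatible wrapper. A minimal sketch of the same wiring outside pytest, assuming the default optimizer/scheduler callables, SegmCallable, and direct instantiation behave exactly as in the fixtures above:

import torch

from otx.core.data.entity.segmentation import SegBatchDataEntity
from otx.core.metrics.dice import SegmCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.segmentation import OTXSegmentationModel
from otx.core.types.label import SegLabelInfo

label_info = SegLabelInfo(
    label_names=["Background", "label_0", "label_1"],
    label_groups=[["Background", "label_0", "label_1"]],
)

# Positional order follows the test fixture:
# label_info, input_size, optimizer, scheduler, metric, torch_compile.
model = OTXSegmentationModel(
    label_info,
    (512, 512),
    DefaultOptimizerCallable,
    DefaultSchedulerCallable,
    SegmCallable,
    False,
)
assert model.num_classes == 3

# A dummy batch shaped like the one used in the prediction tests.
batch = SegBatchDataEntity(
    batch_size=2,
    images=torch.randn(2, 3, 224, 224),
    imgs_info=[],
    masks=[torch.randn(224, 224), torch.randn(224, 224)],
)
dummy = model.get_dummy_input(2)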