Refactor Semantic Segmentation models (#3840)
* started to add MeanTeacher. Needs debugging

* added cutmix

* added other recipes. Updated linter

* fix unit test

* minor

* fix unit test. Add integration

* fix integration test

* add perf tests

* fix export. Delete student model

* remove dead code. Extend unit test

* update changelog

* remove HPO tests for semi-sl. Add Unit test for callback

* fix pylint

* fix test xai

* minor

* SegNeXt refactored

* refactoring semantic seg

* unit tests are not working

* fix unit test

* minor

* fix unit test for utils

* fix last tests

* fix classification torchvision

* fix integration tests

* change mean, scale

* fix recipes

* fix recipes

* remove model from recipe

* remove merge problem

* fix pytorchcv version

---------

Co-authored-by: Shin, Eunwoo <eunwoo.shin@intel.com>
kprokofi and eunwoosh authored Aug 16, 2024
1 parent 7b86e2d commit 53d54af
Showing 48 changed files with 951 additions and 1,739 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -79,7 +79,7 @@ xpu = [
"intel-extension-for-pytorch==2.1.30+xpu",
"oneccl_bind_pt==2.1.300+xpu",
"lightning==2.2",
"pytorchcv",
"pytorchcv==0.0.67",
"timm==1.0.3",
"openvino==2024.3",
"openvino-dev==2024.3",
@@ -93,7 +93,7 @@
base = [
"torch==2.2.2",
"lightning==2.3.3",
"pytorchcv",
"pytorchcv==0.0.67",
"timm==1.0.3",
"openvino==2024.3",
"openvino-dev==2024.3",
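Note: both the xpu and base dependency groups now pin pytorchcv to 0.0.67 instead of leaving it unconstrained. A quick sanity check of which version actually resolved in an installed environment (standard library only; not part of this commit):

from importlib.metadata import version

print(version("pytorchcv"))  # should report 0.0.67 once the pins above are installed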
2 changes: 1 addition & 1 deletion src/otx/algo/classification/torchvision_model.py
@@ -429,7 +429,6 @@ def __init__(
) -> None:
self.backbone = backbone
self.freeze_backbone = freeze_backbone
self.train_type = train_type
self.task = task

# TODO(@harimkang): Need to make it configurable.
@@ -447,6 +446,7 @@ def __init__(
metric=metric,
torch_compile=torch_compile,
input_size=input_size,
train_type=train_type,
)
self.input_size: tuple[int, int]

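The two hunks above stop the torchvision classifier from caching train_type on itself; the value is instead forwarded through the constructor call shown in the second hunk (presumably the parent OTX model initializer). A hypothetical minimal pattern of the same move (class names and the enum are stand-ins, not the real OTX types):

from enum import Enum

class TrainType(str, Enum):  # stand-in for the real train-type enum
    SUPERVISED = "SUPERVISED"
    SEMI_SUPERVISED = "SEMI_SUPERVISED"

class BaseClassifier:
    def __init__(self, train_type: TrainType = TrainType.SUPERVISED) -> None:
        self.train_type = train_type  # the base class now owns the attribute

class TVClassifier(BaseClassifier):
    def __init__(self, backbone: str, train_type: TrainType = TrainType.SUPERVISED) -> None:
        self.backbone = backbone
        super().__init__(train_type=train_type)  # forwarded instead of self.train_type = ...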
4 changes: 2 additions & 2 deletions src/otx/algo/segmentation/backbones/__init__.py
@@ -4,7 +4,7 @@
"""Backbone modules for OTX segmentation model."""

from .dinov2 import DinoVisionTransformer
from .litehrnet import LiteHRNet
from .litehrnet import LiteHRNetBackbone
from .mscan import MSCAN

__all__ = ["LiteHRNet", "DinoVisionTransformer", "MSCAN"]
__all__ = ["LiteHRNetBackbone", "DinoVisionTransformer", "MSCAN"]
6 changes: 2 additions & 4 deletions src/otx/algo/segmentation/backbones/dinov2.py
@@ -13,25 +13,23 @@
import torch
from torch import nn

from otx.algo.modules.base_module import BaseModule
from otx.algo.utils.mmengine_utils import load_checkpoint_to_model, load_from_http
from otx.utils.utils import get_class_initial_arguments

logger = logging.getLogger()


class DinoVisionTransformer(BaseModule):
class DinoVisionTransformer(nn.Module):
"""DINO-v2 Model."""

def __init__(
self,
name: str,
freeze_backbone: bool,
out_index: list[int],
init_cfg: dict | None = None,
pretrained_weights: str | None = None,
):
super().__init__(init_cfg)
super().__init__()
self._init_args = get_class_initial_arguments()

ci_data_root = os.environ.get("CI_DATA_ROOT")
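With BaseModule and init_cfg dropped, DinoVisionTransformer is a plain nn.Module whose weights are handled through the new pretrained_weights argument. A construction sketch from the new signature (the name and out_index values are illustrative assumptions, not taken from this diff):

from otx.algo.segmentation.backbones import DinoVisionTransformer

backbone = DinoVisionTransformer(
    name="dinov2_vits14",      # assumed DINOv2 variant identifier
    freeze_backbone=True,
    out_index=[8, 9, 10, 11],  # assumed indices of the transformer blocks to return
)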
211 changes: 92 additions & 119 deletions src/otx/algo/segmentation/backbones/litehrnet.py
@@ -10,19 +10,15 @@
from __future__ import annotations

from pathlib import Path
from typing import Callable
from typing import Any, Callable, ClassVar

import torch
import torch.utils.checkpoint as cp
from torch import nn
from torch.nn import functional

from otx.algo.modules import Conv2dModule, build_norm_layer
from otx.algo.modules.base_module import BaseModule
from otx.algo.segmentation.modules import (
AsymmetricPositionAttentionModule,
IterativeAggregator,
LocalAttentionModule,
channel_shuffle,
)
from otx.algo.utils.mmengine_utils import load_checkpoint_to_model, load_from_http
@@ -1191,7 +1187,7 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
return out


class LiteHRNet(BaseModule):
class NNLiteHRNet(nn.Module):
"""Lite-HRNet backbone.
`High-Resolution Representations for Labeling Pixels and Regions
@@ -1212,44 +1208,34 @@ class LiteHRNet(BaseModule):

def __init__(
self,
extra: dict,
stem: dict[str, Any],
num_stages: int,
stages_spec: dict[str, Any],
in_channels: int = 3,
norm_cfg: dict | None = None,
norm_cfg: dict[str, Any] | None = None,
norm_eval: bool = False,
with_cp: bool = False,
zero_init_residual: bool = False,
dropout: float | None = None,
init_cfg: dict | None = None,
pretrained_weights: str | None = None,
) -> None:
"""Init."""
super().__init__(init_cfg=init_cfg)
super().__init__()

if norm_cfg is None:
norm_cfg = {"type": "BN"}
norm_cfg = {"type": "BN", "requires_grad": True}

self.extra = extra
self.norm_cfg = norm_cfg
self.norm_eval = norm_eval
self.with_cp = with_cp
self.zero_init_residual = zero_init_residual
self.stem = Stem(
in_channels,
input_norm=self.extra["stem"]["input_norm"],
stem_channels=self.extra["stem"]["stem_channels"],
out_channels=self.extra["stem"]["out_channels"],
expand_ratio=self.extra["stem"]["expand_ratio"],
strides=self.extra["stem"]["strides"],
extra_stride=self.extra["stem"]["extra_stride"],
norm_cfg=self.norm_cfg,
**stem,
)

self.enable_stem_pool = self.extra["stem"].get("out_pool", False)
if self.enable_stem_pool:
self.stem_pool = nn.AvgPool2d(kernel_size=3, stride=2)

self.num_stages = self.extra["num_stages"]
self.stages_spec = self.extra["stages_spec"]
self.num_stages = num_stages
self.stages_spec = stages_spec

num_channels_last = [
self.stem.out_channels,
@@ -1273,80 +1259,6 @@ def __init__(
)
setattr(self, f"stage{i}", stage)

self.out_modules = None
if self.extra.get("out_modules") is not None:
out_modules = []
in_modules_channels, out_modules_channels = num_channels_last[-1], None
if self.extra["out_modules"]["conv"]["enable"]:
out_modules_channels = self.extra["out_modules"]["conv"]["channels"]
out_modules.append(
Conv2dModule(
in_channels=in_modules_channels,
out_channels=out_modules_channels,
kernel_size=1,
stride=1,
padding=0,
norm_cfg=self.norm_cfg,
activation_callable=nn.ReLU,
),
)
in_modules_channels = out_modules_channels
if self.extra["out_modules"]["position_att"]["enable"]:
out_modules.append(
AsymmetricPositionAttentionModule(
in_channels=in_modules_channels,
key_channels=self.extra["out_modules"]["position_att"]["key_channels"],
value_channels=self.extra["out_modules"]["position_att"]["value_channels"],
psp_size=self.extra["out_modules"]["position_att"]["psp_size"],
norm_cfg=self.norm_cfg,
),
)
if self.extra["out_modules"]["local_att"]["enable"]:
out_modules.append(
LocalAttentionModule(
num_channels=in_modules_channels,
norm_cfg=self.norm_cfg,
),
)

if len(out_modules) > 0:
self.out_modules = nn.Sequential(*out_modules)
num_channels_last.append(in_modules_channels)

self.add_stem_features = self.extra.get("add_stem_features", False)
if self.add_stem_features:
self.stem_transition = nn.Sequential(
Conv2dModule(
self.stem.out_channels,
self.stem.out_channels,
kernel_size=3,
stride=1,
padding=1,
groups=self.stem.out_channels,
norm_cfg=norm_cfg,
activation_callable=None,
),
Conv2dModule(
self.stem.out_channels,
num_channels_last[0],
kernel_size=1,
stride=1,
padding=0,
norm_cfg=norm_cfg,
activation_callable=nn.ReLU,
),
)

num_channels_last = [num_channels_last[0], *num_channels_last]

self.with_aggregator = self.extra.get("out_aggregator") and self.extra["out_aggregator"]["enable"]
if self.with_aggregator:
self.aggregator = IterativeAggregator(
in_channels=num_channels_last,
min_channels=self.extra["out_aggregator"].get("min_channels", None),
norm_cfg=self.norm_cfg,
)

if pretrained_weights is not None:
self.load_pretrained_weights(pretrained_weights, prefix="backbone")

@@ -1479,11 +1391,7 @@ def _make_stage(
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function."""
stem_outputs = self.stem(x)
y_x2 = y_x4 = stem_outputs
y = y_x4

if self.enable_stem_pool:
y = self.stem_pool(y)
y = stem_outputs

y_list = [y]
for i in range(self.num_stages):
@@ -1502,21 +1410,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
stage_module = getattr(self, f"stage{i}")
y_list = stage_module(stage_inputs)

if self.out_modules is not None:
y_list.append(self.out_modules(y_list[-1]))

if self.add_stem_features:
y_stem = self.stem_transition(y_x2)
y_list = [y_stem, *y_list]

out = y_list
if self.with_aggregator:
out = self.aggregator(out)

if self.extra.get("add_input", False):
out = [x, *out]

return out
return y_list

def load_pretrained_weights(self, pretrained: str | None = None, prefix: str = "") -> None:
"""Initialize weights."""
Expand All @@ -1530,3 +1424,82 @@ def load_pretrained_weights(self, pretrained: str | None = None, prefix: str = "
print(f"init weight - {pretrained}")
if checkpoint is not None:
load_checkpoint_to_model(self, checkpoint, prefix=prefix)


class LiteHRNetBackbone:
"""LiteHRNet backbone factory."""

LITEHRNET_CFG: ClassVar[dict[str, Any]] = {
"lite_hrnet_s": {
"stem": {
"stem_channels": 32,
"out_channels": 32,
"expand_ratio": 1,
"strides": [2, 2],
"extra_stride": True,
"input_norm": False,
},
"num_stages": 2,
"stages_spec": {
"num_modules": [4, 4],
"num_branches": [2, 3],
"num_blocks": [2, 2],
"module_type": ["LITE", "LITE"],
"with_fuse": [True, True],
"reduce_ratios": [8, 8],
"num_channels": [[60, 120], [60, 120, 240]],
},
"pretrained_weights": "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/models/custom_semantic_segmentation/litehrnetsv2_imagenet1k_rsc.pth",
},
"lite_hrnet_18": {
"stem": {
"stem_channels": 32,
"out_channels": 32,
"expand_ratio": 1,
"strides": [2, 2],
"extra_stride": False,
"input_norm": False,
},
"num_stages": 3,
"stages_spec": {
"num_modules": [2, 4, 2],
"num_branches": [2, 3, 4],
"num_blocks": [2, 2, 2],
"module_type": ["LITE", "LITE", "LITE"],
"with_fuse": [True, True, True],
"reduce_ratios": [8, 8, 8],
"num_channels": [[40, 80], [40, 80, 160], [40, 80, 160, 320]],
},
"pretrained_weights": "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/models/custom_semantic_segmentation/litehrnet18_imagenet1k_rsc.pth",
},
"lite_hrnet_x": {
"stem": {
"stem_channels": 60,
"out_channels": 60,
"expand_ratio": 1,
"strides": [2, 1],
"extra_stride": False,
"input_norm": False,
},
"num_stages": 4,
"stages_spec": {
"weighting_module_version": "v1",
"num_modules": [2, 4, 4, 2],
"num_branches": [2, 3, 4, 5],
"num_blocks": [2, 2, 2, 2],
"module_type": ["LITE", "LITE", "LITE", "LITE"],
"with_fuse": [True, True, True, True],
"reduce_ratios": [2, 4, 8, 8],
"num_channels": [[18, 60], [18, 60, 80], [18, 60, 80, 160], [18, 60, 80, 160, 320]],
},
"pretrained_weights": "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/models/custom_semantic_segmentation/litehrnetxv3_imagenet1k_rsc.pth",
},
}

def __new__(cls, version: str) -> NNLiteHRNet:
"""Constructor for LiteHRNet backbone."""
if version not in cls.LITEHRNET_CFG:
msg = f"model type '{version}' is not supported"
raise KeyError(msg)

return NNLiteHRNet(**cls.LITEHRNET_CFG[version])
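Taken together, NNLiteHRNet now receives explicit stem / num_stages / stages_spec arguments and the LiteHRNetBackbone factory supplies them per variant. A usage sketch based on the configs and __new__ above (the input size is an arbitrary example; constructing a variant also fetches the pretrained weights listed in its config):

import torch

from otx.algo.segmentation.backbones import LiteHRNetBackbone

backbone = LiteHRNetBackbone("lite_hrnet_18")       # __new__ returns an NNLiteHRNet instance
features = backbone(torch.randn(1, 3, 512, 512))    # forward returns a list of multi-scale feature maps
print([f.shape for f in features])

# Unknown variants are rejected by __new__:
# LiteHRNetBackbone("lite_hrnet_99")  -> KeyError: "model type 'lite_hrnet_99' is not supported"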