Add Semi-SL MeanTeacher for semantic segmentation (#3801)
* started to add MeanTeacher. Needs debugging

* added cutmix

* added other recipes. Updated linter

* fix unit test

* minor

* fix unit test. Add integration

* fix integration test

* add perf tests

* fix export. Delete student model

* remove dead code. Extend unit test

* update changelog

* remove HPO tests for semi-sl. Add Unit test for callback

* fix pylint

* fix test xai

* minor
kprokofi authored Aug 13, 2024
1 parent 962d26c commit 2ecaac1
Showing 43 changed files with 1,329 additions and 73 deletions.
6 changes: 4 additions & 2 deletions CHANGELOG.md
@@ -6,15 +6,15 @@ All notable changes to this project will be documented in this file.

### New features

- Add RT-DETR model for object detection task
- Add RT-DETR model for Object Detection
(https://github.com/openvinotoolkit/training_extensions/pull/3741)
- Add Multi-Label & H-label Classification with torchvision models
(https://github.com/openvinotoolkit/training_extensions/pull/3697)
- Add Hugging-Face Model Wrapper for Classification
(https://github.com/openvinotoolkit/training_extensions/pull/3710)
- Add LoRA finetuning capability for ViT Architectures
(https://github.com/openvinotoolkit/training_extensions/pull/3729)
- Add Hugging-Face Model Wrapper for Detection
- Add Hugging-Face Model Wrapper for Object Detection
(https://github.com/openvinotoolkit/training_extensions/pull/3747)
- Add Hugging-Face Model Wrapper for Semantic Segmentation
(https://github.com/openvinotoolkit/training_extensions/pull/3749)
@@ -24,6 +24,8 @@ All notable changes to this project will be documented in this file.
(https://github.com/openvinotoolkit/training_extensions/pull/3762)
- Add RTMPose for Keypoint Detection Task
(https://github.com/openvinotoolkit/training_extensions/pull/3781)
- Add Semi-SL MeanTeacher algorithm for Semantic Segmentation
(https://github.com/openvinotoolkit/training_extensions/pull/3801)
- Update head and h-label format for hierarchical label classification
(https://github.com/openvinotoolkit/training_extensions/pull/3810)

81 changes: 81 additions & 0 deletions src/otx/algo/callbacks/ema_mean_teacher.py
@@ -0,0 +1,81 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Module for exponential moving average for SemiSL mean teacher algorithm."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import torch
from lightning import Callback, LightningModule, Trainer

if TYPE_CHECKING:
    from lightning.pytorch.utilities.types import STEP_OUTPUT


class EMAMeanTeacher(Callback):
    """Callback for the SemiSL MeanTeacher algorithm.

    This callback updates the teacher model's weights as an exponential moving
    average of the student model's weights.

    Args:
        momentum (float, optional): EMA momentum. Defaults to 0.999.
        start_epoch (int, optional): epoch at which the EMA update starts. Defaults to 1.
    """

    def __init__(
        self,
        momentum: float = 0.999,
        start_epoch: int = 1,
    ) -> None:
        super().__init__()
        self.momentum = momentum
        self.start_epoch = start_epoch
        self.synced_models = False

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Set up source (student) and destination (teacher) model parameters."""
        # trainer.model is the LightningModule; its .model attribute is the wrapped nn.Module
        model = trainer.model.model
        self.src_model = getattr(model, "student_model", None)
        self.dst_model = getattr(model, "teacher_model", None)
        if self.src_model is None or self.dst_model is None:
            msg = "student_model and teacher_model should be set for MeanTeacher algorithm"
            raise RuntimeError(msg)
        self.src_params = self.src_model.state_dict(keep_vars=True)
        self.dst_params = self.dst_model.state_dict(keep_vars=True)

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: STEP_OUTPUT,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
    ) -> None:
        """Update the EMA parameters every iteration."""
        if trainer.current_epoch < self.start_epoch:
            return

        # EMA update
        self._ema_model(trainer.global_step)

    def _copy_model(self) -> None:
        with torch.no_grad():
            for name, src_param in self.src_params.items():
                if src_param.requires_grad:
                    dst_param = self.dst_params[name]
                    dst_param.data.copy_(src_param.data)

    def _ema_model(self, global_step: int) -> None:
        if self.start_epoch != 0 and not self.synced_models:
            self._copy_model()
            self.synced_models = True

        momentum = min(1 - 1 / (global_step + 1), self.momentum)
        with torch.no_grad():
            for name, src_param in self.src_params.items():
                if src_param.requires_grad:
                    dst_param = self.dst_params[name]
                    dst_param.data.copy_(dst_param.data * momentum + src_param.data * (1 - momentum))
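
For reference, a minimal standalone sketch (not code from this commit) of how the effective momentum computed in _ema_model warms up with the global step before saturating at the configured value:

momentum = 0.999
for global_step in (0, 9, 99, 999, 9999):
    # same schedule as in _ema_model above
    effective_momentum = min(1 - 1 / (global_step + 1), momentum)
    print(global_step, round(effective_momentum, 4))
# prints: 0 0.0, 9 0.9, 99 0.99, 999 0.999, 9999 0.999
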
67 changes: 67 additions & 0 deletions src/otx/algo/common/utils/utils.py
@@ -16,6 +16,7 @@
from functools import partial
from typing import Callable

import numpy as np
import torch
import torch.distributed as dist
from torch import Tensor
@@ -259,3 +260,69 @@ def inverse_sigmoid(x: Tensor, eps: float = 1e-5) -> Tensor:
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)


def cut_mixer(images: Tensor, masks: Tensor) -> tuple[Tensor, Tensor]:
    """Apply cut-mix augmentation to the input images and masks.

    Args:
        images (Tensor): The input images tensor.
        masks (Tensor): The input masks tensor.

    Returns:
        tuple[Tensor, Tensor]: A tuple containing the augmented images and masks tensors.
    """

    def rand_bbox(size: tuple[int, ...], lam: float) -> tuple[list[int], ...]:
        """Generate random bounding box coordinates.

        Args:
            size (tuple[int, ...]): The size of the input tensor.
            lam (float): The lambda value for cut-mix augmentation.

        Returns:
            tuple[list[int], ...]: The bounding box coordinates (bbx1, bby1, bbx2, bby2).
        """
        # past implementation
        w = size[2]
        h = size[3]
        b = size[0]
        cut_rat = np.sqrt(1.0 - lam)
        cut_w = int(w * cut_rat)
        cut_h = int(h * cut_rat)

        cx = np.random.randint(size=[b], low=int(w / 8), high=w)
        cy = np.random.randint(size=[b], low=int(h / 8), high=h)

        bbx1 = np.clip(cx - cut_w // 2, 0, w)
        bby1 = np.clip(cy - cut_h // 2, 0, h)

        bbx2 = np.clip(cx + cut_w // 2, 0, w)
        bby2 = np.clip(cy + cut_h // 2, 0, h)

        return bbx1, bby1, bbx2, bby2

    target_device = images.device
    mix_data = images.clone()
    mix_masks = masks.clone()
    u_rand_index = torch.randperm(images.size()[0])[: images.size()[0]].to(target_device)
    u_bbx1, u_bby1, u_bbx2, u_bby2 = rand_bbox(images.size(), lam=np.random.beta(4, 4))

    for i in range(mix_data.shape[0]):
        mix_data[i, :, u_bbx1[i] : u_bbx2[i], u_bby1[i] : u_bby2[i]] = images[
            u_rand_index[i],
            :,
            u_bbx1[i] : u_bbx2[i],
            u_bby1[i] : u_bby2[i],
        ]

        mix_masks[i, :, u_bbx1[i] : u_bbx2[i], u_bby1[i] : u_bby2[i]] = masks[
            u_rand_index[i],
            :,
            u_bbx1[i] : u_bbx2[i],
            u_bby1[i] : u_bby2[i],
        ]

    del images, masks

    return mix_data, mix_masks.squeeze(dim=1)
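
A minimal usage sketch (shapes are assumptions, not code from this commit): mixing a batch of four RGB images and their class-index masks with the helper above.

import torch

from otx.algo.common.utils.utils import cut_mixer

images = torch.rand(4, 3, 64, 64)             # N x C x H x W
masks = torch.randint(0, 21, (4, 1, 64, 64))  # N x 1 x H x W class-index masks
mixed_images, mixed_masks = cut_mixer(images, masks)
print(mixed_images.shape, mixed_masks.shape)  # torch.Size([4, 3, 64, 64]) torch.Size([4, 64, 64])
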
4 changes: 2 additions & 2 deletions src/otx/algo/segmentation/__init__.py
@@ -3,6 +3,6 @@
#
"""Module for OTX segmentation models, hooks, utils, etc."""

from . import backbones, heads, losses
from . import backbones, heads, losses, segmentors

__all__ = ["backbones", "heads", "losses"]
__all__ = ["backbones", "heads", "losses", "segmentors"]
49 changes: 47 additions & 2 deletions src/otx/algo/segmentation/dino_v2_seg.py
@@ -7,12 +7,14 @@

from typing import TYPE_CHECKING, Any, ClassVar

import torch

from otx.algo.segmentation.backbones import DinoVisionTransformer
from otx.algo.segmentation.heads import FCNHead
from otx.algo.segmentation.segmentors import BaseSegmModel, MeanTeacher
from otx.core.data.entity.segmentation import SegBatchDataEntity
from otx.core.model.segmentation import TorchVisionCompatibleModel

from .base_model import BaseSegmModel

if TYPE_CHECKING:
    from torch import nn
    from typing_extensions import Self
@@ -68,3 +70,46 @@ def to(self, *args, **kwargs) -> Self:
msg = f"{type(self).__name__} doesn't support XPU."
raise RuntimeError(msg)
return ret


class DinoV2SegSemiSL(OTXDinoV2Seg):
    """DinoV2SegSemiSL Model."""

    def _customize_inputs(self, entity: SegBatchDataEntity) -> dict[str, Any]:
        if not isinstance(entity, dict):
            if self.training:
                msg = "unlabeled inputs should be provided for semi-sl training"
                raise RuntimeError(msg)
            return super()._customize_inputs(entity)

        entity["labeled"].masks = torch.stack(entity["labeled"].masks).long()
        w_u_images = entity["weak_transforms"].images
        s_u_images = entity["strong_transforms"].images
        unlabeled_img_metas = entity["weak_transforms"].imgs_info
        labeled_inputs = entity["labeled"]

        return {
            "inputs": labeled_inputs.images,
            "unlabeled_weak_images": w_u_images,
            "unlabeled_strong_images": s_u_images,
            "global_step": self.trainer.global_step,
            "steps_per_epoch": self.trainer.num_training_batches,
            "img_metas": labeled_inputs.imgs_info,
            "unlabeled_img_metas": unlabeled_img_metas,
            "masks": labeled_inputs.masks,
            "mode": "loss",
        }

    def _create_model(self) -> nn.Module:
        # merge the user configuration over the defaults (user values take precedence)
        backbone_configuration = DinoV2Seg.default_backbone_configuration | self.backbone_configuration
        decode_head_configuration = DinoV2Seg.default_decode_head_configuration | self.decode_head_configuration
        backbone = DinoVisionTransformer(**backbone_configuration)
        decode_head = FCNHead(num_classes=self.num_classes, **decode_head_configuration)
        base_model = DinoV2Seg(
            backbone=backbone,
            decode_head=decode_head,
            criterion_configuration=self.criterion_configuration,
        )

        return MeanTeacher(base_model, unsup_weight=0.7, drop_unrel_pixels_percent=20, semisl_start_epoch=2)
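
For illustration, a sketch (not code from this commit) of the wrapping step performed by _create_model above, applicable to any BaseSegmModel instance; the hyperparameter meanings are inferred from their names and are assumptions, not documented semantics:

from otx.algo.segmentation.segmentors import MeanTeacher


def wrap_with_mean_teacher(base_model):
    """Wrap a BaseSegmModel in the semi-supervised MeanTeacher segmentor."""
    return MeanTeacher(
        base_model,
        unsup_weight=0.7,              # assumed: weight of the unsupervised loss term
        drop_unrel_pixels_percent=20,  # assumed: drop the least reliable pseudo-label pixels
        semisl_start_epoch=2,          # assumed: epoch at which unlabeled data starts being used
    )
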
53 changes: 51 additions & 2 deletions src/otx/algo/segmentation/litehrnet.py
@@ -7,17 +7,18 @@

from typing import TYPE_CHECKING, Any, ClassVar

import torch
from torch.onnx import OperatorExportTypes

from otx.algo.segmentation.backbones import LiteHRNet
from otx.algo.segmentation.heads import FCNHead
from otx.algo.segmentation.segmentors import BaseSegmModel, MeanTeacher
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.data.entity.segmentation import SegBatchDataEntity
from otx.core.exporter.base import OTXModelExporter
from otx.core.exporter.native import OTXNativeModelExporter
from otx.core.model.segmentation import TorchVisionCompatibleModel

from .base_model import BaseSegmModel

if TYPE_CHECKING:
    from torch import nn

@@ -574,3 +575,51 @@ def _exporter(self) -> OTXModelExporter:
            onnx_export_configuration={"operator_export_type": OperatorExportTypes.ONNX_ATEN_FALLBACK},
            output_names=None,
        )


class LiteHRNetSemiSL(OTXLiteHRNet):
    """LiteHRNetSemiSL Model."""

    def _create_model(self) -> nn.Module:
        litehrnet_model_class = LITEHRNET_VARIANTS[self.name_base_model]
        # merge the user configuration over the variant defaults (user values take precedence)
        backbone_configuration = litehrnet_model_class.default_backbone_configuration | self.backbone_configuration
        decode_head_configuration = (
            litehrnet_model_class.default_decode_head_configuration | self.decode_head_configuration
        )
        # initialize backbone and decode head
        backbone = LiteHRNet(**backbone_configuration)
        decode_head = FCNHead(num_classes=self.num_classes, **decode_head_configuration)

        base_model = litehrnet_model_class(
            backbone=backbone,
            decode_head=decode_head,
            criterion_configuration=self.criterion_configuration,
        )

        return MeanTeacher(base_model, unsup_weight=0.7, drop_unrel_pixels_percent=20, semisl_start_epoch=2)

    def _customize_inputs(self, entity: SegBatchDataEntity) -> dict[str, Any]:
        if not isinstance(entity, dict):
            if self.training:
                msg = "unlabeled inputs should be provided for semi-sl training"
                raise RuntimeError(msg)
            return super()._customize_inputs(entity)

        entity["labeled"].masks = torch.stack(entity["labeled"].masks).long()
        w_u_images = entity["weak_transforms"].images
        s_u_images = entity["strong_transforms"].images
        unlabeled_img_metas = entity["weak_transforms"].imgs_info
        labeled_inputs = entity["labeled"]

        return {
            "inputs": labeled_inputs.images,
            "unlabeled_weak_images": w_u_images,
            "unlabeled_strong_images": s_u_images,
            "global_step": self.trainer.global_step,
            "steps_per_epoch": self.trainer.num_training_batches,
            "img_metas": labeled_inputs.imgs_info,
            "unlabeled_img_metas": unlabeled_img_metas,
            "masks": labeled_inputs.masks,
            "mode": "loss",
        }
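
The semi-SL _customize_inputs methods above expect the training step to receive a dictionary rather than a plain SegBatchDataEntity. A sketch of that layout, inferred from the attribute accesses in the code (the key names come from the code itself; the descriptions are assumptions):

# {
#     "labeled":           SegBatchDataEntity with .images, .masks and .imgs_info,
#     "weak_transforms":   unlabeled SegBatchDataEntity after weak augmentations,
#     "strong_transforms": the same unlabeled images after strong augmentations,
# }
# These entries are flattened into the MeanTeacher forward kwargs shown above
# ("inputs", "unlabeled_weak_images", "unlabeled_strong_images", "masks", ...).
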
9 changes: 9 additions & 0 deletions src/otx/algo/segmentation/segmentors/__init__.py
@@ -0,0 +1,9 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Module for base NN segmentation models."""

from .base_model import BaseSegmModel
from .mean_teacher import MeanTeacher

__all__ = ["BaseSegmModel", "MeanTeacher"]