Add Semi-SL MeanTeacher for semantic segmentation #3801

Merged: 22 commits, merged on Aug 13, 2024
6 changes: 4 additions & 2 deletions CHANGELOG.md
@@ -6,15 +6,15 @@ All notable changes to this project will be documented in this file.

### New features

- Add RT-DETR model for object detection task
- Add RT-DETR model for Object Detection
(https://github.com/openvinotoolkit/training_extensions/pull/3741)
- Add Multi-Label & H-label Classification with torchvision models
(https://github.com/openvinotoolkit/training_extensions/pull/3697)
- Add Hugging-Face Model Wrapper for Classification
(https://github.com/openvinotoolkit/training_extensions/pull/3710)
- Add LoRA finetuning capability for ViT Architectures
(https://github.com/openvinotoolkit/training_extensions/pull/3729)
- Add Hugging-Face Model Wrapper for Detection
- Add Hugging-Face Model Wrapper for Object Detection
(https://github.com/openvinotoolkit/training_extensions/pull/3747)
- Add Hugging-Face Model Wrapper for Semantic Segmentation
(https://github.com/openvinotoolkit/training_extensions/pull/3749)
@@ -24,6 +24,8 @@ All notable changes to this project will be documented in this file.
(https://github.com/openvinotoolkit/training_extensions/pull/3762)
- Add RTMPose for Keypoint Detection Task
(https://github.com/openvinotoolkit/training_extensions/pull/3781)
- Add Semi-SL MeanTeacher algorithm for Semantic Segmentation
(https://github.com/openvinotoolkit/training_extensions/pull/3801)
- Update head and h-label format for hierarchical label classification
(https://github.com/openvinotoolkit/training_extensions/pull/3810)

81 changes: 81 additions & 0 deletions src/otx/algo/callbacks/ema_mean_teacher.py
@@ -0,0 +1,81 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Module for exponential moving average for SemiSL mean teacher algorithm."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import torch
from lightning import Callback, LightningModule, Trainer

if TYPE_CHECKING:
    from lightning.pytorch.utilities.types import STEP_OUTPUT


class EMAMeanTeacher(Callback):
    """Callback for the SemiSL MeanTeacher algorithm.

    This callback keeps the teacher model's weights as an exponential moving average (EMA)
    of the student model's weights.

    Args:
        momentum (float, optional): EMA momentum used to update the teacher weights. Defaults to 0.999.
        start_epoch (int, optional): epoch at which EMA updates start. Defaults to 1.
    """

    def __init__(
        self,
        momentum: float = 0.999,
        start_epoch: int = 1,
    ) -> None:
        super().__init__()
        self.momentum = momentum
        self.start_epoch = start_epoch
        self.synced_models = False

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Set up src & dst model parameters."""
        # access the underlying nn.Module of the LightningModule
        model = trainer.model.model
        self.src_model = getattr(model, "student_model", None)
        self.dst_model = getattr(model, "teacher_model", None)
        if self.src_model is None or self.dst_model is None:
            msg = "student_model and teacher_model should be set for MeanTeacher algorithm"
            raise RuntimeError(msg)
        self.src_params = self.src_model.state_dict(keep_vars=True)
        self.dst_params = self.dst_model.state_dict(keep_vars=True)

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: STEP_OUTPUT,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
    ) -> None:
        """Update the EMA parameters every iteration."""
        if trainer.current_epoch < self.start_epoch:
            return

        # EMA update of the teacher weights
        self._ema_model(trainer.global_step)

    def _copy_model(self) -> None:
        with torch.no_grad():
            for name, src_param in self.src_params.items():
                if src_param.requires_grad:
                    dst_param = self.dst_params[name]
                    dst_param.data.copy_(src_param.data)

    def _ema_model(self, global_step: int) -> None:
        if self.start_epoch != 0 and not self.synced_models:
            self._copy_model()
            self.synced_models = True

        momentum = min(1 - 1 / (global_step + 1), self.momentum)
        with torch.no_grad():
            for name, src_param in self.src_params.items():
                if src_param.requires_grad:
                    dst_param = self.dst_params[name]
                    dst_param.data.copy_(dst_param.data * momentum + src_param.data * (1 - momentum))
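
For illustration, the callback plugs into a standard Lightning training loop. A minimal sketch, assuming a semi-SL LightningModule whose `.model` attribute exposes `student_model` and `teacher_model` submodules; the module and datamodule names below are hypothetical placeholders, not part of this PR.

from lightning import Trainer

from otx.algo.callbacks.ema_mean_teacher import EMAMeanTeacher

# hypothetical semi-SL segmentation module and datamodule, for illustration only;
# module.model must expose `student_model` and `teacher_model` attributes
module = SemiSLSegLitModule()
datamodule = SemiSLSegDataModule()

trainer = Trainer(
    max_epochs=100,
    callbacks=[EMAMeanTeacher(momentum=0.999, start_epoch=1)],
)
trainer.fit(module, datamodule=datamodule)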
67 changes: 67 additions & 0 deletions src/otx/algo/common/utils/utils.py
@@ -16,6 +16,7 @@
from functools import partial
from typing import Callable

import numpy as np
import torch
import torch.distributed as dist
from torch import Tensor
@@ -259,3 +260,69 @@ def inverse_sigmoid(x: Tensor, eps: float = 1e-5) -> Tensor:
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)


def cut_mixer(images: Tensor, masks: Tensor) -> tuple[Tensor, Tensor]:
    """Apply CutMix augmentation to the input images and masks.

    Args:
        images (Tensor): The input images tensor.
        masks (Tensor): The input masks tensor.

    Returns:
        tuple[Tensor, Tensor]: A tuple containing the augmented images and masks tensors.
    """

    def rand_bbox(size: tuple[int, ...], lam: float) -> tuple[list[int], ...]:
        """Generate random bounding box coordinates.

        Args:
            size (tuple[int, ...]): The size of the input tensor (NCHW).
            lam (float): The lambda value for CutMix augmentation.

        Returns:
            tuple[list[int], ...]: The per-sample bounding box coordinates (bbx1, bby1, bbx2, bby2).
        """
        # derive the per-sample cut size from lam, as in the previous implementation
        w = size[2]
        h = size[3]
        b = size[0]
        cut_rat = np.sqrt(1.0 - lam)
        cut_w = int(w * cut_rat)
        cut_h = int(h * cut_rat)

        cx = np.random.randint(size=[b], low=int(w / 8), high=w)
        cy = np.random.randint(size=[b], low=int(h / 8), high=h)

        bbx1 = np.clip(cx - cut_w // 2, 0, w)
        bby1 = np.clip(cy - cut_h // 2, 0, h)

        bbx2 = np.clip(cx + cut_w // 2, 0, w)
        bby2 = np.clip(cy + cut_h // 2, 0, h)

        return bbx1, bby1, bbx2, bby2

    target_device = images.device
    mix_data = images.clone()
    mix_masks = masks.clone()
    u_rand_index = torch.randperm(images.size()[0])[: images.size()[0]].to(target_device)
    u_bbx1, u_bby1, u_bbx2, u_bby2 = rand_bbox(images.size(), lam=np.random.beta(4, 4))

    for i in range(mix_data.shape[0]):
        mix_data[i, :, u_bbx1[i] : u_bbx2[i], u_bby1[i] : u_bby2[i]] = images[
            u_rand_index[i],
            :,
            u_bbx1[i] : u_bbx2[i],
            u_bby1[i] : u_bby2[i],
        ]

        mix_masks[i, :, u_bbx1[i] : u_bbx2[i], u_bby1[i] : u_bby2[i]] = masks[
            u_rand_index[i],
            :,
            u_bbx1[i] : u_bbx2[i],
            u_bby1[i] : u_bby2[i],
        ]

    del images, masks

    return mix_data, mix_masks.squeeze(dim=1)
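
As a quick usage sketch (not part of this diff), cut_mixer can be exercised on random tensors; the batch size, image size, and class count below are arbitrary assumptions.

import torch

from otx.algo.common.utils.utils import cut_mixer

images = torch.randn(4, 3, 256, 256)            # NCHW image batch
masks = torch.randint(0, 19, (4, 1, 256, 256))  # per-pixel class ids with a channel dimension

mixed_images, mixed_masks = cut_mixer(images, masks)
print(mixed_images.shape)  # torch.Size([4, 3, 256, 256])
print(mixed_masks.shape)   # torch.Size([4, 256, 256]) -- channel dimension squeezed out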
4 changes: 2 additions & 2 deletions src/otx/algo/segmentation/__init__.py
@@ -3,6 +3,6 @@
#
"""Module for OTX segmentation models, hooks, utils, etc."""

from . import backbones, heads, losses
from . import backbones, heads, losses, segmentors

__all__ = ["backbones", "heads", "losses"]
__all__ = ["backbones", "heads", "losses", "segmentors"]
49 changes: 47 additions & 2 deletions src/otx/algo/segmentation/dino_v2_seg.py
@@ -7,12 +7,14 @@

from typing import TYPE_CHECKING, Any, ClassVar

import torch

from otx.algo.segmentation.backbones import DinoVisionTransformer
from otx.algo.segmentation.heads import FCNHead
from otx.algo.segmentation.segmentors import BaseSegmModel, MeanTeacher
from otx.core.data.entity.segmentation import SegBatchDataEntity
from otx.core.model.segmentation import TorchVisionCompatibleModel

from .base_model import BaseSegmModel

if TYPE_CHECKING:
    from torch import nn
    from typing_extensions import Self
@@ -68,3 +70,46 @@
msg = f"{type(self).__name__} doesn't support XPU."
raise RuntimeError(msg)
return ret


class DinoV2SegSemiSL(OTXDinoV2Seg):
    """DinoV2Seg model trained with the Semi-SL MeanTeacher algorithm."""

    def _customize_inputs(self, entity: SegBatchDataEntity) -> dict[str, Any]:
        if not isinstance(entity, dict):
            if self.training:
                msg = "unlabeled inputs should be provided for semi-sl training"
                raise RuntimeError(msg)
            return super()._customize_inputs(entity)

        entity["labeled"].masks = torch.stack(entity["labeled"].masks).long()
        w_u_images = entity["weak_transforms"].images
        s_u_images = entity["strong_transforms"].images
        unlabeled_img_metas = entity["weak_transforms"].imgs_info
        labeled_inputs = entity["labeled"]

        return {
            "inputs": labeled_inputs.images,
            "unlabeled_weak_images": w_u_images,
            "unlabeled_strong_images": s_u_images,
            "global_step": self.trainer.global_step,
            "steps_per_epoch": self.trainer.num_training_batches,
            "img_metas": labeled_inputs.imgs_info,
            "unlabeled_img_metas": unlabeled_img_metas,
            "masks": labeled_inputs.masks,
            "mode": "loss",
        }

    def _create_model(self) -> nn.Module:
        # merge default and user configurations; user-provided values override the defaults
        backbone_configuration = DinoV2Seg.default_backbone_configuration | self.backbone_configuration
        decode_head_configuration = DinoV2Seg.default_decode_head_configuration | self.decode_head_configuration
        backbone = DinoVisionTransformer(**backbone_configuration)
        decode_head = FCNHead(num_classes=self.num_classes, **decode_head_configuration)
        base_model = DinoV2Seg(
            backbone=backbone,
            decode_head=decode_head,
            criterion_configuration=self.criterion_configuration,
        )

        return MeanTeacher(base_model, unsup_weight=0.7, drop_unrel_pixels_percent=20, semisl_start_epoch=2)
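
The MeanTeacher wrapper's internals are not part of this diff. As a rough, hedged sketch of the general mean-teacher scheme such a wrapper implements (illustrative only, not the OTX code; the percentile-based pixel filtering merely mirrors the drop_unrel_pixels_percent argument, and the student/teacher callables are assumed to return per-pixel logits):

import torch
from torch.nn import functional as F


def mean_teacher_step(student, teacher, labeled_imgs, labels, weak_imgs, strong_imgs,
                      unsup_weight=0.7, drop_percent=20):
    """Illustrative semi-supervised step: the teacher pseudo-labels weakly augmented
    unlabeled images, and the student learns from labeled data plus strong views."""
    sup_loss = F.cross_entropy(student(labeled_imgs), labels, ignore_index=255)

    with torch.no_grad():
        probs = teacher(weak_imgs).softmax(dim=1)
        conf, pseudo_labels = probs.max(dim=1)
        # drop roughly `drop_percent` percent of the least reliable pixels
        threshold = torch.quantile(conf.flatten(), drop_percent / 100)
        pseudo_labels[conf < threshold] = 255  # ignored by the loss

    unsup_loss = F.cross_entropy(student(strong_imgs), pseudo_labels, ignore_index=255)
    return sup_loss + unsup_weight * unsup_loss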

53 changes: 51 additions & 2 deletions src/otx/algo/segmentation/litehrnet.py
@@ -7,17 +7,18 @@

from typing import TYPE_CHECKING, Any, ClassVar

import torch
from torch.onnx import OperatorExportTypes

from otx.algo.segmentation.backbones import LiteHRNet
from otx.algo.segmentation.heads import FCNHead
from otx.algo.segmentation.segmentors import BaseSegmModel, MeanTeacher
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.data.entity.segmentation import SegBatchDataEntity
from otx.core.exporter.base import OTXModelExporter
from otx.core.exporter.native import OTXNativeModelExporter
from otx.core.model.segmentation import TorchVisionCompatibleModel

from .base_model import BaseSegmModel

if TYPE_CHECKING:
    from torch import nn

@@ -574,3 +575,51 @@
            onnx_export_configuration={"operator_export_type": OperatorExportTypes.ONNX_ATEN_FALLBACK},
            output_names=None,
        )


class LiteHRNetSemiSL(OTXLiteHRNet):
    """LiteHRNet model trained with the Semi-SL MeanTeacher algorithm."""

    def _create_model(self) -> nn.Module:
        litehrnet_model_class = LITEHRNET_VARIANTS[self.name_base_model]
        # merge default and user configurations; user-provided values override the defaults
        backbone_configuration = litehrnet_model_class.default_backbone_configuration | self.backbone_configuration
        decode_head_configuration = (
            litehrnet_model_class.default_decode_head_configuration | self.decode_head_configuration
        )
        # initialize the backbone and decode head
        backbone = LiteHRNet(**backbone_configuration)
        decode_head = FCNHead(num_classes=self.num_classes, **decode_head_configuration)

        base_model = litehrnet_model_class(
            backbone=backbone,
            decode_head=decode_head,
            criterion_configuration=self.criterion_configuration,
        )

        return MeanTeacher(base_model, unsup_weight=0.7, drop_unrel_pixels_percent=20, semisl_start_epoch=2)

    def _customize_inputs(self, entity: SegBatchDataEntity) -> dict[str, Any]:
        if not isinstance(entity, dict):
            if self.training:
                msg = "unlabeled inputs should be provided for semi-sl training"
                raise RuntimeError(msg)
            return super()._customize_inputs(entity)

        entity["labeled"].masks = torch.stack(entity["labeled"].masks).long()
        w_u_images = entity["weak_transforms"].images
        s_u_images = entity["strong_transforms"].images
        unlabeled_img_metas = entity["weak_transforms"].imgs_info
        labeled_inputs = entity["labeled"]

        return {
            "inputs": labeled_inputs.images,
            "unlabeled_weak_images": w_u_images,
            "unlabeled_strong_images": s_u_images,
            "global_step": self.trainer.global_step,
            "steps_per_epoch": self.trainer.num_training_batches,
            "img_metas": labeled_inputs.imgs_info,
            "unlabeled_img_metas": unlabeled_img_metas,
            "masks": labeled_inputs.masks,
            "mode": "loss",
        }
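
For reference, a hedged sketch of the training batch that _customize_inputs expects in semi-SL mode; only the three dictionary keys and the attributes read above come from the code, the variable names are placeholders and the datamodule wiring is outside this diff.

# Illustrative structure of a semi-SL training batch.
batch = {
    "labeled": labeled_batch,               # SegBatchDataEntity with .images, .masks, .imgs_info
    "weak_transforms": weak_unlabeled,      # SegBatchDataEntity of weakly augmented unlabeled images
    "strong_transforms": strong_unlabeled,  # SegBatchDataEntity of strongly augmented unlabeled images
}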
9 changes: 9 additions & 0 deletions src/otx/algo/segmentation/segmentors/__init__.py
@@ -0,0 +1,9 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Module for base NN segmentation models."""

from .base_model import BaseSegmModel
from .mean_teacher import MeanTeacher

__all__ = ["BaseSegmModel", "MeanTeacher"]