Merge branch 'develop' into kp/add_semisl
kprokofi authored Aug 13, 2024
2 parents c267bbf + 962d26c commit dcad371
Showing 30 changed files with 420 additions and 83 deletions.
45 changes: 15 additions & 30 deletions .github/dependabot.yaml
@@ -24,24 +24,19 @@ updates:
     directory: /
     schedule:
       interval: daily
-    ignore:
-      - dependency-name: "torch"
-      - dependency-name: "torchvision"
-      - dependency-name: "lightning"
-      - dependency-name: "pytorchcv"
-      - dependency-name: "timm"
-      - dependency-name: "openvino*"
-      - dependency-name: "nncf"
-      - dependency-name: "anomalib"
-      - dependency-name: "intel-extension-for-pytorch"
-      - dependency-name: "oneccl_bind_pt"
     groups:
       pip-base-dependencies:
+        applies-to: version-updates
         patterns:
           - "torch"
           - "lightning"
           - "pytorchcv"
           - "timm"
           - "openvino"
           - "openvino-dev"
           - "openvino-model-api"
           - "onnx"
           - "onnxconverter-common"
           - "nncf"
           - "anomalib"
         update-types:
           - "patch"
-      pip-mmlab-dependencies:
+      pip-mmlab:
+        applies-to: version-updates
         patterns:
           - "mmdet"
@@ -52,20 +52,10 @@
           - "oss2"
         update-types:
           - "patch"
-      pip-other-dependencies:
+      pip-others:
+        applies-to: version-updates
-        exclude-patterns:
-          - "torch"
-          - "lightning"
-          - "pytorchcv"
-          - "timm"
-          - "openvino"
-          - "openvino-dev"
-          - "openvino-model-api"
-          - "onnx"
-          - "onnxconverter-common"
-          - "nncf"
-          - "anomalib"
         patterns:
           - "*"
         update-types:
           - "minor"
           - "patch"
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -26,6 +26,8 @@ All notable changes to this project will be documented in this file.
   (https://github.com/openvinotoolkit/training_extensions/pull/3781)
 - Add Semi-SL MeanTeacher algorithm for Semantic Segmentation
   (https://github.com/openvinotoolkit/training_extensions/pull/3801)
+- Update head and h-label format for hierarchical label classification
+  (https://github.com/openvinotoolkit/training_extensions/pull/3810)
 
 ### Enhancements
 
12 changes: 7 additions & 5 deletions pyproject.toml
@@ -156,18 +156,16 @@ include = ["otx*"]
 
 # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
 # COVERAGE CONFIGURATION. #
-[tool.coverage.paths]
-source = [
-    "src",
-]
-
 [tool.coverage.report]
 exclude_lines = [
     "pragma: no cover",
     "if TYPE_CHECKING:",
 ]
 
 [tool.coverage.run]
+source = [
+    "src/otx/",
+]
 omit = [
     "**/__init__.py",
     "src/otx/recipes/*",
@@ -184,6 +182,10 @@ omit = [
     "src/otx/core/data/transform_libs/mmseg.py",
     "src/otx/core/exporter/mmdeploy.py",
     "src/otx/core/model/utils/*",
+
+    # Ignore some generated files by opencv-python
+    "config.py",
+    "config-3.py",
 ]
3 changes: 2 additions & 1 deletion src/otx/algo/classification/classifier/base_classifier.py
@@ -61,6 +61,7 @@ def __init__(
         neck: nn.Module | None,
         head: nn.Module,
         pretrained: str | None = None,
+        optimize_gap: bool = True,
         mean: list[float] | None = None,
         std: list[float] | None = None,
         to_rgb: bool = False,
@@ -81,7 +82,7 @@ def __init__(
         self.explainer = ReciproCAM(
             self._head_forward_fn,
             num_classes=head.num_classes,
-            optimize_gap=True,
+            optimize_gap=optimize_gap,
         )
 
     def forward(
7 changes: 4 additions & 3 deletions src/otx/algo/classification/efficientnet.py
@@ -15,7 +15,7 @@
 from otx.algo.classification.classifier.base_classifier import ImageClassifier
 from otx.algo.classification.classifier.semi_sl_classifier import SemiSLClassifier
 from otx.algo.classification.heads import (
-    HierarchicalLinearClsHead,
+    HierarchicalCBAMClsHead,
     LinearClsHead,
     MultiLabelLinearClsHead,
     OTXSemiSLLinearClsHead,
@@ -265,13 +265,14 @@ def _build_model(self, head_config: dict) -> nn.Module:
         backbone = OTXEfficientNet(version=self.version, pretrained=self.pretrained)
         return ImageClassifier(
             backbone=backbone,
-            neck=GlobalAveragePooling(dim=2),
-            head=HierarchicalLinearClsHead(
+            neck=nn.Identity(),
+            head=HierarchicalCBAMClsHead(
                 in_channels=backbone.num_features,
                 multiclass_loss=nn.CrossEntropyLoss(),
                 multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
                 **head_config,
             ),
+            optimize_gap=False,
         )
 
     def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
3 changes: 2 additions & 1 deletion src/otx/algo/classification/heads/__init__.py
@@ -3,7 +3,7 @@
 #
 """Head modules for OTX custom model."""
 
-from .hlabel_cls_head import HierarchicalLinearClsHead, HierarchicalNonLinearClsHead
+from .hlabel_cls_head import HierarchicalCBAMClsHead, HierarchicalLinearClsHead, HierarchicalNonLinearClsHead
 from .linear_head import LinearClsHead
 from .multilabel_cls_head import MultiLabelLinearClsHead, MultiLabelNonLinearClsHead
 from .semi_sl_head import OTXSemiSLLinearClsHead, OTXSemiSLVisionTransformerClsHead
@@ -15,6 +15,7 @@
     "MultiLabelNonLinearClsHead",
     "HierarchicalLinearClsHead",
     "HierarchicalNonLinearClsHead",
+    "HierarchicalCBAMClsHead",
     "VisionTransformerClsHead",
     "OTXSemiSLLinearClsHead",
     "OTXSemiSLVisionTransformerClsHead",
133 changes: 133 additions & 0 deletions src/otx/algo/classification/heads/hlabel_cls_head.py
@@ -353,3 +353,136 @@ def forward(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor:
         """The forward process."""
         pre_logits = self.pre_logits(feats)
         return self.classifier(pre_logits)
+
+
+class ChannelAttention(nn.Module):
+    """Channel attention module that uses average and max pooling to enhance important channels."""
+
+    def __init__(self, in_channels: int, reduction: int = 16):
+        """Initializes the ChannelAttention module."""
+        super().__init__()
+        self.fc1 = nn.Conv2d(in_channels, in_channels // reduction, kernel_size=1, bias=False)
+        self.fc2 = nn.Conv2d(in_channels // reduction, in_channels, kernel_size=1, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Applies channel attention to the input tensor."""
+        avg_out = self.fc2(torch.relu(self.fc1(torch.mean(x, dim=2, keepdim=True).mean(dim=3, keepdim=True))))
+        max_out = self.fc2(torch.relu(self.fc1(torch.max(x, dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0])))
+        return torch.sigmoid(avg_out + max_out)
+
+
+class SpatialAttention(nn.Module):
+    """Spatial attention module that uses average and max pooling to enhance important spatial locations."""
+
+    def __init__(self, kernel_size: int = 7):
+        """Initializes the SpatialAttention module."""
+        super().__init__()
+        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Applies spatial attention to the input tensor."""
+        avg_out = torch.mean(x, dim=1, keepdim=True)
+        max_out = torch.max(x, dim=1, keepdim=True)[0]
+        x = torch.cat([avg_out, max_out], dim=1)
+        return torch.sigmoid(self.conv(x))
+
+
+class CBAM(nn.Module):
+    """CBAM module that applies channel and spatial attention sequentially."""
+
+    def __init__(self, in_channels: int, reduction: int = 16, kernel_size: int = 7):
+        """Initializes the CBAM module with channel and spatial attention."""
+        super().__init__()
+        self.channel_attention = ChannelAttention(in_channels, reduction)
+        self.spatial_attention = SpatialAttention(kernel_size)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Applies channel and spatial attention to the input tensor."""
+        x = x * self.channel_attention(x)
+        return x * self.spatial_attention(x)
+
+
+class HierarchicalCBAMClsHead(HierarchicalClsHead):
+    """Custom CBAM classification head for the hierarchical classification task.
+
+    Args:
+        num_multiclass_heads (int): Number of multi-class heads.
+        num_multilabel_classes (int): Number of multi-label classes.
+        head_idx_to_logits_range (dict[str, tuple[int, int]]): The logit range of each head.
+        num_single_label_classes (int): The number of single-label classes.
+        empty_multiclass_head_indices (list[int]): Indices of heads left without any label
+            after label removal.
+        in_channels (int): Number of channels in the input feature map.
+        num_classes (int): Number of total classes.
+        multiclass_loss (nn.Module): Config of the multi-class loss.
+        multilabel_loss (nn.Module | None, optional): Config of the multi-label loss.
+        thr (float, optional): Predictions with scores under the threshold are considered
+            as negative. Defaults to 0.5.
+        init_cfg (dict | None, optional): Initialization configuration key-values. Defaults to None.
+        step_size (int, optional): Step size value for HierarchicalCBAMClsHead. Defaults to 7.
+    """
+
+    def __init__(
+        self,
+        num_multiclass_heads: int,
+        num_multilabel_classes: int,
+        head_idx_to_logits_range: dict[str, tuple[int, int]],
+        num_single_label_classes: int,
+        empty_multiclass_head_indices: list[int],
+        in_channels: int,
+        num_classes: int,
+        multiclass_loss: nn.Module,
+        multilabel_loss: nn.Module | None = None,
+        thr: float = 0.5,
+        init_cfg: dict | None = None,
+        step_size: int = 7,
+        **kwargs,
+    ):
+        super().__init__(
+            num_multiclass_heads=num_multiclass_heads,
+            num_multilabel_classes=num_multilabel_classes,
+            head_idx_to_logits_range=head_idx_to_logits_range,
+            num_single_label_classes=num_single_label_classes,
+            empty_multiclass_head_indices=empty_multiclass_head_indices,
+            in_channels=in_channels,
+            num_classes=num_classes,
+            multiclass_loss=multiclass_loss,
+            multilabel_loss=multilabel_loss,
+            thr=thr,
+            init_cfg=init_cfg,
+            **kwargs,
+        )
+        self.step_size = step_size
+        self.fc_superclass = nn.Linear(in_channels * step_size * step_size, num_multiclass_heads)
+        self.attention_fc = nn.Linear(num_multiclass_heads, in_channels * step_size * step_size)
+        self.cbam = CBAM(in_channels)
+        self.fc_subclass = nn.Linear(in_channels * step_size * step_size, num_single_label_classes)
+
+        self._init_layers()
+
+    def pre_logits(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor:
+        """The process before the final classification head."""
+        if isinstance(feats, Sequence):
+            feats = feats[-1]
+        return feats.view(feats.size(0), self.in_channels * self.step_size * self.step_size)
+
+    def _init_layers(self) -> None:
+        """Initialize weights of the classification head."""
+        normal_init(self.fc_superclass, mean=0, std=0.01, bias=0)
+        normal_init(self.fc_subclass, mean=0, std=0.01, bias=0)
+
+    def forward(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor:
+        """The forward process."""
+        pre_logits = self.pre_logits(feats)
+        out_superclass = self.fc_superclass(pre_logits)
+
+        attention_weights = torch.sigmoid(self.attention_fc(out_superclass))
+        attended_features = pre_logits * attention_weights
+
+        attended_features = attended_features.view(pre_logits.size(0), self.in_channels, self.step_size, self.step_size)
+        attended_features = self.cbam(attended_features)
+        attended_features = attended_features.view(
+            pre_logits.size(0),
+            self.in_channels * self.step_size * self.step_size,
+        )
+        return self.fc_subclass(attended_features)
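
A note on the new head's data flow (not part of the commit; a minimal sketch assuming the module path above, since CBAM itself is not re-exported from heads/__init__.py): the CBAM block reweights a spatial feature map without changing its shape, and HierarchicalCBAMClsHead then flattens that map into in_channels * step_size * step_size features for its linear layers — which is presumably why the model builders in this commit swap the GlobalAveragePooling neck for nn.Identity() and pass optimize_gap=False to the explainer.

    import torch

    # Assumed import path, taken from the file location in this commit.
    from otx.algo.classification.heads.hlabel_cls_head import CBAM

    # Shapes mirror the MobileNetV3 recipe below: 960 channels, 7x7 feature map.
    feats = torch.randn(2, 960, 7, 7)

    cbam = CBAM(in_channels=960, reduction=16, kernel_size=7)
    out = cbam(feats)
    print(out.shape)  # torch.Size([2, 960, 7, 7]) -- attention only rescales values

    # HierarchicalCBAMClsHead.pre_logits flattens the attended map before fc_subclass.
    flat = out.view(out.size(0), 960 * 7 * 7)
    print(flat.shape)  # torch.Size([2, 47040])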
7 changes: 4 additions & 3 deletions src/otx/algo/classification/mobilenet_v3.py
@@ -15,7 +15,7 @@
 from otx.algo.classification.backbones import OTXMobileNetV3
 from otx.algo.classification.classifier import ImageClassifier, SemiSLClassifier
 from otx.algo.classification.heads import (
-    HierarchicalNonLinearClsHead,
+    HierarchicalCBAMClsHead,
     LinearClsHead,
     MultiLabelNonLinearClsHead,
     OTXSemiSLLinearClsHead,
@@ -325,13 +325,14 @@ def _build_model(self, head_config: dict) -> nn.Module:
 
         return ImageClassifier(
             backbone=OTXMobileNetV3(mode=self.mode),
-            neck=GlobalAveragePooling(dim=2),
-            head=HierarchicalNonLinearClsHead(
+            neck=nn.Identity(),
+            head=HierarchicalCBAMClsHead(
                 in_channels=960,
                 multiclass_loss=nn.CrossEntropyLoss(),
                 multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
                 **head_config,
             ),
+            optimize_gap=False,
         )
 
     def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
7 changes: 4 additions & 3 deletions src/otx/algo/classification/timm_model.py
@@ -13,7 +13,7 @@
 from otx.algo.classification.backbones.timm import TimmBackbone, TimmModelType
 from otx.algo.classification.classifier import ImageClassifier, SemiSLClassifier
 from otx.algo.classification.heads import (
-    HierarchicalLinearClsHead,
+    HierarchicalCBAMClsHead,
     LinearClsHead,
     MultiLabelLinearClsHead,
     OTXSemiSLLinearClsHead,
@@ -264,13 +264,14 @@ def _build_model(self, head_config: dict) -> nn.Module:
         backbone = TimmBackbone(backbone=self.backbone, pretrained=self.pretrained)
         return ImageClassifier(
             backbone=backbone,
-            neck=GlobalAveragePooling(dim=2),
-            head=HierarchicalLinearClsHead(
+            neck=nn.Identity(),
+            head=HierarchicalCBAMClsHead(
                 in_channels=backbone.num_features,
                 multiclass_loss=nn.CrossEntropyLoss(),
                 multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
                 **head_config,
            ),
+            optimize_gap=False,
         )
 
     def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
7 changes: 4 additions & 3 deletions src/otx/algo/classification/torchvision_model.py
@@ -11,7 +11,7 @@
 from torch import Tensor, nn
 from torchvision.models import get_model, get_model_weights
 
-from otx.algo.classification.heads import HierarchicalLinearClsHead, MultiLabelLinearClsHead, OTXSemiSLLinearClsHead
+from otx.algo.classification.heads import HierarchicalCBAMClsHead, MultiLabelLinearClsHead, OTXSemiSLLinearClsHead
 from otx.algo.classification.losses import AsymmetricAngularLossWithIgnore
 from otx.algo.explain.explain_algo import ReciproCAM, feature_vector_fn
 from otx.core.data.entity.base import OTXBatchLossEntity
@@ -209,12 +209,13 @@ def _get_head(self, net: nn.Module) -> nn.Module:
                 loss=self.loss_module,
             )
         if self.task == OTXTaskType.H_LABEL_CLS:
-            self.neck = nn.Sequential(*layers) if layers else None
-            return HierarchicalLinearClsHead(
+            self.neck = nn.Sequential(*layers, nn.Identity()) if layers else None
+            return HierarchicalCBAMClsHead(
                 in_channels=feature_channel,
                 multiclass_loss=nn.CrossEntropyLoss(),
                 multilabel_loss=self.loss_module,
                 **self.head_config,
+                step_size=1,
             )
 
         msg = f"Task type {self.task} is not supported."
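
One hedged observation on the torchvision path (an inference from the diff, not a statement from the commit): the step_size=1 override suggests the nn.Sequential neck built from the model's own trailing layers already pools features to a 1x1 map, so the head's flattened width collapses to in_channels. A tiny sizing sketch with an illustrative channel count:

    import torch
    from torch import nn

    in_channels, step_size = 1280, 1  # 1280 is illustrative; step_size=1 as in this commit
    flat_dim = in_channels * step_size * step_size  # == in_channels when step_size == 1

    # With step_size=1, each nn.Linear in the head reduces to a plain
    # (in_channels -> out_features) projection over the pooled feature vector.
    fc = nn.Linear(flat_dim, 6)  # 6 = hypothetical number of subclass logits
    pooled = torch.randn(4, in_channels, 1, 1).flatten(1)
    print(fc(pooled).shape)  # torch.Size([4, 6])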