Update head and h-label format for hierarchical label classification #3810

Merged
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -24,6 +24,8 @@ All notable changes to this project will be documented in this file.
   (https://github.com/openvinotoolkit/training_extensions/pull/3762)
 - Add RTMPose for Keypoint Detection Task
   (https://github.com/openvinotoolkit/training_extensions/pull/3781)
+- Update head and h-label format for hierarchical label classification
+  (https://github.com/openvinotoolkit/training_extensions/pull/3810)

 ### Enhancements

3 changes: 2 additions & 1 deletion src/otx/algo/classification/classifier/base_classifier.py
@@ -61,6 +61,7 @@ def __init__(
         neck: nn.Module | None,
         head: nn.Module,
         pretrained: str | None = None,
+        optimize_gap: bool = True,
         mean: list[float] | None = None,
         std: list[float] | None = None,
         to_rgb: bool = False,
@@ -81,7 +82,7 @@ def __init__(
         self.explainer = ReciproCAM(
             self._head_forward_fn,
             num_classes=head.num_classes,
-            optimize_gap=True,
+            optimize_gap=optimize_gap,
         )

     def forward(
7 changes: 4 additions & 3 deletions src/otx/algo/classification/efficientnet.py
@@ -15,7 +15,7 @@
 from otx.algo.classification.classifier.base_classifier import ImageClassifier
 from otx.algo.classification.classifier.semi_sl_classifier import SemiSLClassifier
 from otx.algo.classification.heads import (
-    HierarchicalLinearClsHead,
+    HierarchicalCBAMClsHead,
     LinearClsHead,
     MultiLabelLinearClsHead,
     OTXSemiSLLinearClsHead,
@@ -265,13 +265,14 @@ def _build_model(self, head_config: dict) -> nn.Module:
         backbone = OTXEfficientNet(version=self.version, pretrained=self.pretrained)
         return ImageClassifier(
             backbone=backbone,
-            neck=GlobalAveragePooling(dim=2),
-            head=HierarchicalLinearClsHead(
+            neck=nn.Identity(),
+            head=HierarchicalCBAMClsHead(
                 in_channels=backbone.num_features,
                 multiclass_loss=nn.CrossEntropyLoss(),
                 multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
                 **head_config,
             ),
+            optimize_gap=False,
         )

     def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
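The same re-wiring recurs in the MobileNetV3, timm, and torchvision models below: the GlobalAveragePooling neck becomes nn.Identity() so the CBAM head receives the full spatial feature map, and optimize_gap=False propagates that fact to the ReciproCAM explainer. A minimal torch-only sketch of the shape difference (the Conv2d backbone is a stand-in, not an OTX class):

    import torch
    from torch import nn

    backbone = nn.Conv2d(3, 16, kernel_size=3, padding=1)  # stand-in backbone
    gap_neck = nn.AdaptiveAvgPool2d(1)  # roughly what GlobalAveragePooling(dim=2) produced
    identity_neck = nn.Identity()       # this PR: keep the spatial map

    feats = backbone(torch.randn(1, 3, 28, 28))
    print(gap_neck(feats).shape)       # torch.Size([1, 16, 1, 1])  - pooled to a vector
    print(identity_neck(feats).shape)  # torch.Size([1, 16, 28, 28]) - preserved for CBAM
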
3 changes: 2 additions & 1 deletion src/otx/algo/classification/heads/__init__.py
@@ -3,7 +3,7 @@
 #
 """Head modules for OTX custom model."""

-from .hlabel_cls_head import HierarchicalLinearClsHead, HierarchicalNonLinearClsHead
+from .hlabel_cls_head import HierarchicalCBAMClsHead, HierarchicalLinearClsHead, HierarchicalNonLinearClsHead
 from .linear_head import LinearClsHead
 from .multilabel_cls_head import MultiLabelLinearClsHead, MultiLabelNonLinearClsHead
 from .semi_sl_head import OTXSemiSLLinearClsHead, OTXSemiSLVisionTransformerClsHead
@@ -15,6 +15,7 @@
     "MultiLabelNonLinearClsHead",
     "HierarchicalLinearClsHead",
     "HierarchicalNonLinearClsHead",
+    "HierarchicalCBAMClsHead",
     "VisionTransformerClsHead",
     "OTXSemiSLLinearClsHead",
     "OTXSemiSLVisionTransformerClsHead",
133 changes: 133 additions & 0 deletions src/otx/algo/classification/heads/hlabel_cls_head.py
@@ -353,3 +353,136 @@ def forward(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor:
"""The forward process."""
pre_logits = self.pre_logits(feats)
return self.classifier(pre_logits)


class ChannelAttention(nn.Module):
"""Channel attention module that uses average and max pooling to enhance important channels."""

def __init__(self, in_channels: int, reduction: int = 16):
"""Initializes the ChannelAttention module."""
super().__init__()
self.fc1 = nn.Conv2d(in_channels, in_channels // reduction, kernel_size=1, bias=False)
self.fc2 = nn.Conv2d(in_channels // reduction, in_channels, kernel_size=1, bias=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Applies channel attention to the input tensor."""
avg_out = self.fc2(torch.relu(self.fc1(torch.mean(x, dim=2, keepdim=True).mean(dim=3, keepdim=True))))
max_out = self.fc2(torch.relu(self.fc1(torch.max(x, dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0])))
return torch.sigmoid(avg_out + max_out)


class SpatialAttention(nn.Module):
"""Spatial attention module that uses average and max pooling to enhance important spatial locations."""

def __init__(self, kernel_size: int = 7):
"""Initializes the SpatialAttention module."""
super().__init__()
self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Applies spatial attention to the input tensor."""
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out = torch.max(x, dim=1, keepdim=True)[0]
x = torch.cat([avg_out, max_out], dim=1)
return torch.sigmoid(self.conv(x))


class CBAM(nn.Module):
"""CBAM module that applies channel and spatial attention sequentially."""

def __init__(self, in_channels: int, reduction: int = 16, kernel_size: int = 7):
"""Initializes the CBAM module with channel and spatial attention."""
super().__init__()
self.channel_attention = ChannelAttention(in_channels, reduction)
self.spatial_attention = SpatialAttention(kernel_size)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Applies channel and spatial attention to the input tensor."""
x = x * self.channel_attention(x)
return x * self.spatial_attention(x)


class HierarchicalCBAMClsHead(HierarchicalClsHead):
"""Custom classification CBAM head for hierarchical classification task.

Args:
num_multiclass_heads (int): Number of multi-class heads.
num_multilabel_classes (int): Number of multi-label classes.
head_idx_to_logits_range (dict[str, tuple[int, int]]): the logit range of each heads
num_single_label_classes (int): the number of single label classes
empty_multiclass_head_indices (list[int]): the index of head that doesn't include any label
due to the label removing
in_channels (int): Number of channels in the input feature map.
num_classes (int): Number of total classes.
multiclass_loss (nn.Module): Config of multi-class loss.
multilabel_loss (nn.Module | None, optional): Config of multi-label loss.
thr (float, optional): Predictions with scores under the thresholds are considered
as negative. Defaults to 0.5.
init_cfg (dict | None, optional): Initialize configuration key-values, Defaults to None.
step_size (int, optional): Step size value for HierarchicalCBAMClsHead, Defaults to 7.
"""

def __init__(
self,
num_multiclass_heads: int,
num_multilabel_classes: int,
head_idx_to_logits_range: dict[str, tuple[int, int]],
num_single_label_classes: int,
empty_multiclass_head_indices: list[int],
in_channels: int,
num_classes: int,
multiclass_loss: nn.Module,
multilabel_loss: nn.Module | None = None,
thr: float = 0.5,
init_cfg: dict | None = None,
step_size: int = 7,
**kwargs,
):
super().__init__(
num_multiclass_heads=num_multiclass_heads,
num_multilabel_classes=num_multilabel_classes,
head_idx_to_logits_range=head_idx_to_logits_range,
num_single_label_classes=num_single_label_classes,
empty_multiclass_head_indices=empty_multiclass_head_indices,
in_channels=in_channels,
num_classes=num_classes,
multiclass_loss=multiclass_loss,
multilabel_loss=multilabel_loss,
thr=thr,
init_cfg=init_cfg,
**kwargs,
)
self.step_size = step_size
self.fc_superclass = nn.Linear(in_channels * step_size * step_size, num_multiclass_heads)
self.attention_fc = nn.Linear(num_multiclass_heads, in_channels * step_size * step_size)
self.cbam = CBAM(in_channels)
self.fc_subclass = nn.Linear(in_channels * step_size * step_size, num_single_label_classes)

self._init_layers()

def pre_logits(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor:
"""The process before the final classification head."""
if isinstance(feats, Sequence):
feats = feats[-1]
return feats.view(feats.size(0), self.in_channels * self.step_size * self.step_size)

def _init_layers(self) -> None:
"""Iniitialize weights of classification head."""
normal_init(self.fc_superclass, mean=0, std=0.01, bias=0)
normal_init(self.fc_subclass, mean=0, std=0.01, bias=0)

def forward(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor:
"""The forward process."""
pre_logits = self.pre_logits(feats)
out_superclass = self.fc_superclass(pre_logits)

attention_weights = torch.sigmoid(self.attention_fc(out_superclass))
attended_features = pre_logits * attention_weights

attended_features = attended_features.view(pre_logits.size(0), self.in_channels, self.step_size, self.step_size)
attended_features = self.cbam(attended_features)
attended_features = attended_features.view(
pre_logits.size(0),
self.in_channels * self.step_size * self.step_size,
)
return self.fc_subclass(attended_features)
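
A minimal shape-check sketch of the new classes, assuming the import paths shown in this diff; the label-group configuration values below are hypothetical:

    import torch
    from torch import nn

    from otx.algo.classification.heads import HierarchicalCBAMClsHead
    from otx.algo.classification.heads.hlabel_cls_head import CBAM

    # CBAM preserves the input shape: (B, C, H, W) -> (B, C, H, W).
    x = torch.randn(2, 960, 7, 7)
    assert CBAM(in_channels=960)(x).shape == x.shape

    # The head flattens the (B, C, step_size, step_size) map, predicts superclass
    # logits, re-weights the flattened features with attention derived from them,
    # refines the reshaped map with CBAM, and emits the single-label class logits.
    head = HierarchicalCBAMClsHead(
        num_multiclass_heads=2,
        num_multilabel_classes=2,
        head_idx_to_logits_range={"0": (0, 2), "1": (2, 5)},  # hypothetical ranges
        num_single_label_classes=5,
        empty_multiclass_head_indices=[],
        in_channels=960,
        num_classes=7,
        multiclass_loss=nn.CrossEntropyLoss(),
        step_size=7,
    )
    print(head(x).shape)  # torch.Size([2, 5]) - one logit per single-label class
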
7 changes: 4 additions & 3 deletions src/otx/algo/classification/mobilenet_v3.py
@@ -15,7 +15,7 @@
 from otx.algo.classification.backbones import OTXMobileNetV3
 from otx.algo.classification.classifier import ImageClassifier, SemiSLClassifier
 from otx.algo.classification.heads import (
-    HierarchicalNonLinearClsHead,
+    HierarchicalCBAMClsHead,
     LinearClsHead,
     MultiLabelNonLinearClsHead,
     OTXSemiSLLinearClsHead,
@@ -325,13 +325,14 @@ def _build_model(self, head_config: dict) -> nn.Module:

         return ImageClassifier(
             backbone=OTXMobileNetV3(mode=self.mode),
-            neck=GlobalAveragePooling(dim=2),
-            head=HierarchicalNonLinearClsHead(
+            neck=nn.Identity(),
+            head=HierarchicalCBAMClsHead(
                 in_channels=960,
                 multiclass_loss=nn.CrossEntropyLoss(),
                 multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
                 **head_config,
             ),
+            optimize_gap=False,
         )

     def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
7 changes: 4 additions & 3 deletions src/otx/algo/classification/timm_model.py
@@ -13,7 +13,7 @@
 from otx.algo.classification.backbones.timm import TimmBackbone, TimmModelType
 from otx.algo.classification.classifier import ImageClassifier, SemiSLClassifier
 from otx.algo.classification.heads import (
-    HierarchicalLinearClsHead,
+    HierarchicalCBAMClsHead,
     LinearClsHead,
     MultiLabelLinearClsHead,
     OTXSemiSLLinearClsHead,
@@ -264,13 +264,14 @@ def _build_model(self, head_config: dict) -> nn.Module:
         backbone = TimmBackbone(backbone=self.backbone, pretrained=self.pretrained)
         return ImageClassifier(
             backbone=backbone,
-            neck=GlobalAveragePooling(dim=2),
-            head=HierarchicalLinearClsHead(
+            neck=nn.Identity(),
+            head=HierarchicalCBAMClsHead(
                 in_channels=backbone.num_features,
                 multiclass_loss=nn.CrossEntropyLoss(),
                 multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
                 **head_config,
             ),
+            optimize_gap=False,
         )

     def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
7 changes: 4 additions & 3 deletions src/otx/algo/classification/torchvision_model.py
@@ -11,7 +11,7 @@
 from torch import Tensor, nn
 from torchvision.models import get_model, get_model_weights

-from otx.algo.classification.heads import HierarchicalLinearClsHead, MultiLabelLinearClsHead, OTXSemiSLLinearClsHead
+from otx.algo.classification.heads import HierarchicalCBAMClsHead, MultiLabelLinearClsHead, OTXSemiSLLinearClsHead
 from otx.algo.classification.losses import AsymmetricAngularLossWithIgnore
 from otx.algo.explain.explain_algo import ReciproCAM, feature_vector_fn
 from otx.core.data.entity.base import OTXBatchLossEntity
@@ -209,12 +209,13 @@ def _get_head(self, net: nn.Module) -> nn.Module:
                 loss=self.loss_module,
             )
         if self.task == OTXTaskType.H_LABEL_CLS:
-            self.neck = nn.Sequential(*layers) if layers else None
-            return HierarchicalLinearClsHead(
+            self.neck = nn.Sequential(*layers, nn.Identity()) if layers else None
+            return HierarchicalCBAMClsHead(
                 in_channels=feature_channel,
                 multiclass_loss=nn.CrossEntropyLoss(),
                 multilabel_loss=self.loss_module,
                 **self.head_config,
+                step_size=1,
             )

         msg = f"Task type {self.task} is not supported."
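Passing step_size=1 here (and in vit.py below) covers backbones whose features are already pooled or projected to a single spatial position: the flattened feature length then equals in_channels, and CBAM's spatial attention acts on a single 1x1 location. A hedged sketch with hypothetical class counts:

    import torch
    from torch import nn

    from otx.algo.classification.heads import HierarchicalCBAMClsHead

    head = HierarchicalCBAMClsHead(
        num_multiclass_heads=1,
        num_multilabel_classes=0,
        head_idx_to_logits_range={"0": (0, 3)},  # hypothetical
        num_single_label_classes=3,
        empty_multiclass_head_indices=[],
        in_channels=512,
        num_classes=3,
        multiclass_loss=nn.CrossEntropyLoss(),
        step_size=1,  # features arrive as (B, 512), i.e. 512 * 1 * 1
    )
    pooled = torch.randn(4, 512)  # already-pooled features from the neck
    print(head(pooled).shape)     # torch.Size([4, 3])
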
6 changes: 4 additions & 2 deletions src/otx/algo/classification/vit.py
@@ -18,7 +18,7 @@
 from otx.algo.classification.backbones.vision_transformer import VIT_ARCH_TYPE, VisionTransformer
 from otx.algo.classification.classifier import ImageClassifier, SemiSLClassifier
 from otx.algo.classification.heads import (
-    HierarchicalLinearClsHead,
+    HierarchicalCBAMClsHead,
     MultiLabelLinearClsHead,
     OTXSemiSLVisionTransformerClsHead,
     VisionTransformerClsHead,
@@ -494,11 +494,13 @@ def _build_model(self, head_config: dict) -> nn.Module:
         return ImageClassifier(
             backbone=vit_backbone,
             neck=None,
-            head=HierarchicalLinearClsHead(
+            head=HierarchicalCBAMClsHead(
                 in_channels=vit_backbone.embed_dim,
                 multiclass_loss=nn.CrossEntropyLoss(),
                 multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
+                step_size=1,
                 **head_config,
             ),
+            optimize_gap=False,
             init_cfg=init_cfg,
         )
23 changes: 15 additions & 8 deletions src/otx/core/data/dataset/classification.py
@@ -145,7 +145,7 @@ def _get_label_group_idx(label_name: str) -> int:

         def _find_ancestor_recursively(label_name: str, ancestors: list) -> list[str]:
             _, dm_label_category = self.dm_categories.find(label_name)
-            parent_name = dm_label_category.parent
+            parent_name = dm_label_category.parent if dm_label_category else ""

             if parent_name != "":
                 ancestors.append(parent_name)
@@ -200,21 +200,22 @@ def _convert_label_to_hlabel_format(self, label_anns: list[Label], ignored_label
         """Convert format of the label to the h-label.

         It converts the label format to h-label format.
+        The total length of the result is the number of multiclass heads plus the number of multilabel classes.

         i.e.
         Let's assume that we used the same dataset as in the example of the definition of HLabelData
-        and the original labels are ["Rigid", "Panda", "Lion"].
+        and the original labels are ["Rigid", "Triangle", "Lion"].

-        Then, h-label format will be [1, -1, 0, 1, 1].
+        Then, h-label format will be [0, 1, 1, 0].
         The first N indices represent the label index of each multiclass head (N=num_multiclass_heads),
         the rest represent the multilabel labels.

-        [Multiclass Heads: [1, -1]]
-        0-th index = 1 -> ["Non-Rigid"(X), "Rigid"(O)] <- First multiclass head
-        1-st index = -1 -> ["Rectangle"(X), "Triangle"(X)] <- Second multiclass head
+        [Multiclass Heads]
+        0-th index = 0 -> ["Rigid"(O), "Non-Rigid"(X)] <- First multiclass head
+        1-st index = 1 -> ["Rectangle"(X), "Triangle"(O), "Circle"(X)] <- Second multiclass head

-        [Multilabel Head: [0, 1, 1]]
-        2, 3, 4 indices = [0, 1, 1] -> ["Circle"(X), "Lion"(O), "Panda"(O)]
+        [Multilabel Head]
+        2, 3 indices = [1, 0] -> ["Lion"(O), "Panda"(X)]
         """
         if not isinstance(self.label_info, HLabelInfo):
             msg = f"The type of label_info should be HLabelInfo, got {type(self.label_info)}."
@@ -229,10 +230,16 @@ def _convert_label_to_hlabel_format(self, label_anns: list[Label], ignored_label

         for ann in label_anns:
             ann_name = self.dm_categories.items[ann.label].name
+            ann_parent = self.dm_categories.items[ann.label].parent
             group_idx, in_group_idx = self.label_info.class_to_group_idx[ann_name]
+            (parent_group_idx, parent_in_group_idx) = (
+                self.label_info.class_to_group_idx[ann_parent] if ann_parent else (None, None)
+            )

             if group_idx < num_multiclass_heads:
                 class_indices[group_idx] = in_group_idx
+                if parent_group_idx is not None and parent_in_group_idx is not None:
+                    class_indices[parent_group_idx] = parent_in_group_idx
             elif not ignored_labels or ann.label not in ignored_labels:
                 class_indices[num_multiclass_heads + in_group_idx] = 1
             else:
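
A self-contained sketch of the encoding this method now produces for the docstring example, including the new parent propagation; the mapping dictionaries are hypothetical stand-ins for what HLabelInfo provides:

    num_multiclass_heads = 2
    num_multilabel_classes = 2
    class_to_group_idx = {  # hypothetical, mirrors the docstring example
        "Rigid": (0, 0), "Non-Rigid": (0, 1),
        "Rectangle": (1, 0), "Triangle": (1, 1), "Circle": (1, 2),
        "Lion": (2, 0), "Panda": (2, 1),
    }
    parents = {"Rectangle": "Rigid", "Triangle": "Rigid", "Circle": "Non-Rigid"}

    encoded = [-1] * num_multiclass_heads + [0] * num_multilabel_classes
    for name in ["Rigid", "Triangle", "Lion"]:
        group_idx, in_group_idx = class_to_group_idx[name]
        if group_idx < num_multiclass_heads:
            encoded[group_idx] = in_group_idx
            parent = parents.get(name)
            if parent:  # new in this PR: also fill the parent's head slot
                pg, pi = class_to_group_idx[parent]
                encoded[pg] = pi
        else:
            encoded[num_multiclass_heads + in_group_idx] = 1

    print(encoded)  # [0, 1, 1, 0], as in the updated docstring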