From bd4745d6394d88ece7e569abfe022308f2c41625 Mon Sep 17 00:00:00 2001 From: sooahleex Date: Wed, 31 Jul 2024 15:55:19 +0900 Subject: [PATCH 01/16] Add CBAM head for h-label cls --- src/otx/algo/classification/heads/__init__.py | 3 +- .../classification/heads/hlabel_cls_head.py | 135 ++++++++++++++++++ src/otx/core/data/dataset/classification.py | 3 +- .../heads/test_hlabel_cls_head.py | 8 +- 4 files changed, 145 insertions(+), 4 deletions(-) diff --git a/src/otx/algo/classification/heads/__init__.py b/src/otx/algo/classification/heads/__init__.py index aea5e1f0f4c..435688be314 100644 --- a/src/otx/algo/classification/heads/__init__.py +++ b/src/otx/algo/classification/heads/__init__.py @@ -3,7 +3,7 @@ # """Head modules for OTX custom model.""" -from .hlabel_cls_head import HierarchicalLinearClsHead, HierarchicalNonLinearClsHead +from .hlabel_cls_head import HierarchicalLinearClsHead, HierarchicalNonLinearClsHead, HierarchicalCBAMClsHead from .linear_head import LinearClsHead from .multilabel_cls_head import MultiLabelLinearClsHead, MultiLabelNonLinearClsHead from .semi_sl_head import OTXSemiSLLinearClsHead, OTXSemiSLVisionTransformerClsHead @@ -15,6 +15,7 @@ "MultiLabelNonLinearClsHead", "HierarchicalLinearClsHead", "HierarchicalNonLinearClsHead", + "HierarchicalCBAMClsHead", "VisionTransformerClsHead", "OTXSemiSLLinearClsHead", "OTXSemiSLVisionTransformerClsHead", diff --git a/src/otx/algo/classification/heads/hlabel_cls_head.py b/src/otx/algo/classification/heads/hlabel_cls_head.py index b976b83642f..2a4289e9805 100644 --- a/src/otx/algo/classification/heads/hlabel_cls_head.py +++ b/src/otx/algo/classification/heads/hlabel_cls_head.py @@ -353,3 +353,138 @@ def forward(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor: """The forward process.""" pre_logits = self.pre_logits(feats) return self.classifier(pre_logits) + +class ChannelAttention(nn.Module): + def __init__(self, in_channels, reduction=16): + super(ChannelAttention, self).__init__() + self.fc1 = nn.Conv2d(in_channels, in_channels // reduction, kernel_size=1, bias=False) + self.fc2 = nn.Conv2d(in_channels // reduction, in_channels, kernel_size=1, bias=False) + + def forward(self, x): + avg_out = self.fc2(torch.relu(self.fc1(torch.mean(x, dim=2, keepdim=True).mean(dim=3, keepdim=True)))) + max_out = self.fc2(torch.relu(self.fc1(torch.max(x, dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]))) + return torch.sigmoid(avg_out + max_out) + +class SpatialAttention(nn.Module): + def __init__(self, kernel_size=7): + super(SpatialAttention, self).__init__() + self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False) + + def forward(self, x): + avg_out = torch.mean(x, dim=1, keepdim=True) + max_out = torch.max(x, dim=1, keepdim=True)[0] + x = torch.cat([avg_out, max_out], dim=1) + return torch.sigmoid(self.conv(x)) + +class CBAM(nn.Module): + def __init__(self, in_channels, reduction=16, kernel_size=7): + super(CBAM, self).__init__() + self.channel_attention = ChannelAttention(in_channels, reduction) + self.spatial_attention = SpatialAttention(kernel_size) + + def forward(self, x): + x = x * self.channel_attention(x) + x = x * self.spatial_attention(x) + return x + +class HierarchicalCBAMClsHead(HierarchicalClsHead): + """Custom classification CBAM head for hierarchical classification task. + + Args: + num_multiclass_heads (int): Number of multi-class heads. + num_multilabel_classes (int): Number of multi-label classes. + head_idx_to_logits_range: the logit range of each heads + num_single_label_classes: the number of single label classes + empty_multiclass_head_indices: the index of head that doesn't include any label + due to the label removing + in_channels (int): Number of channels in the input feature map. + num_classes (int): Number of total classes. + multiclass_loss (dict | None): Config of multi-class loss. + multilabel_loss (dict | None): Config of multi-label loss. + thr (float | None): Predictions with scores under the thresholds are considered + as negative. Defaults to 0.5. + hid_cahnnels (int): Number of channels in the hidden feature map at the classifier. + acivation_Cfg (dict | None): Config of activation layer at the classifier. + dropout (bool): Flag for the enabling the dropout at the classifier. + + """ + + def __init__( + self, + num_multiclass_heads: int, + num_multilabel_classes: int, + head_idx_to_logits_range: dict[str, tuple[int, int]], + num_single_label_classes: int, + empty_multiclass_head_indices: list[int], + in_channels: int, + num_classes: int, + multiclass_loss: nn.Module, + multilabel_loss: nn.Module | None = None, + thr: float = 0.5, + hid_channels: int = 1280, + activation_callable: Callable[[], nn.Module] = nn.ReLU, + dropout: bool = False, + init_cfg: dict | None = None, + **kwargs, + ): + super().__init__( + num_multiclass_heads=num_multiclass_heads, + num_multilabel_classes=num_multilabel_classes, + head_idx_to_logits_range=head_idx_to_logits_range, + num_single_label_classes=num_single_label_classes, + empty_multiclass_head_indices=empty_multiclass_head_indices, + in_channels=in_channels, + num_classes=num_classes, + multiclass_loss=multiclass_loss, + multilabel_loss=multilabel_loss, + thr=thr, + init_cfg=init_cfg, + **kwargs, + ) + self.hid_channels = hid_channels + self.dropout = dropout + + self.activation_callable = activation_callable + + self.fc_superclass = nn.Linear(in_channels, num_multiclass_heads) + self.attention_fc = nn.Linear(num_multiclass_heads, in_channels) + self.cbam = CBAM(in_channels) + + classifier_modules = [ + nn.Linear(in_channels, hid_channels), + nn.BatchNorm1d(hid_channels), + self.activation_callable if isinstance(self.activation_callable, nn.Module) else self.activation_callable(), + ] + + if self.dropout: + classifier_modules.append(nn.Dropout(p=0.2)) + + classifier_modules.append(nn.Linear(hid_channels, num_classes)) + + self.fc_subclass = nn.Sequential(*classifier_modules) + + self._init_layers() + + def _init_layers(self) -> None: + """Iniitialize weights of classification head.""" + normal_init(self.fc_superclass, mean=0, std=0.01, bias=0) + for module in self.fc_subclass: + if isinstance(module, nn.Linear): + normal_init(module, mean=0, std=0.01, bias=0) + elif isinstance(module, nn.BatchNorm1d): + constant_init(module, 1) + + def forward(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + out_superclass = self.fc_superclass(pre_logits) + + attention_weights = torch.sigmoid(self.attention_fc(out_superclass)) + attended_features = pre_logits * attention_weights + + attended_features = attended_features.view(pre_logits.size(0), self.in_channels, 1, 1) + attended_features = self.cbam(attended_features) + attended_features = attended_features.view(pre_logits.size(0), self.in_channels) + out_subclass = self.fc_subclass(attended_features) + + return out_subclass diff --git a/src/otx/core/data/dataset/classification.py b/src/otx/core/data/dataset/classification.py index a26ffc2799e..1d0b82524cb 100644 --- a/src/otx/core/data/dataset/classification.py +++ b/src/otx/core/data/dataset/classification.py @@ -145,7 +145,8 @@ def _get_label_group_idx(label_name: str) -> int: def _find_ancestor_recursively(label_name: str, ancestors: list) -> list[str]: _, dm_label_category = self.dm_categories.find(label_name) - parent_name = dm_label_category.parent + # parent_name = dm_label_category.parent + parent_name = dm_label_category.parent if dm_label_category else "" if parent_name != "": ancestors.append(parent_name) diff --git a/tests/unit/algo/classification/heads/test_hlabel_cls_head.py b/tests/unit/algo/classification/heads/test_hlabel_cls_head.py index 0b7880b9d96..4511ac7fc49 100644 --- a/tests/unit/algo/classification/heads/test_hlabel_cls_head.py +++ b/tests/unit/algo/classification/heads/test_hlabel_cls_head.py @@ -8,7 +8,7 @@ import pytest import torch -from otx.algo.classification.heads import HierarchicalLinearClsHead, HierarchicalNonLinearClsHead +from otx.algo.classification.heads import HierarchicalLinearClsHead, HierarchicalNonLinearClsHead, HierarchicalCBAMClsHead from otx.algo.classification.losses import AsymmetricAngularLossWithIgnore from otx.core.data.entity.base import ImageInfo from torch import nn @@ -65,7 +65,11 @@ def fxt_hlabel_linear_head(self, fxt_head_attrs) -> nn.Module: def fxt_hlabel_non_linear_head(self, fxt_head_attrs) -> nn.Module: return HierarchicalNonLinearClsHead(**fxt_head_attrs) - @pytest.fixture(params=["fxt_hlabel_linear_head", "fxt_hlabel_non_linear_head"]) + @pytest.fixture() + def fxt_hlabel_cbam_head(self, fxt_head_attrs) -> nn.Module: + return HierarchicalCBAMClsHead(**fxt_head_attrs) + + @pytest.fixture(params=["fxt_hlabel_linear_head", "fxt_hlabel_non_linear_head", "fxt_hlabel_cbam_head"]) def fxt_hlabel_head(self, request) -> nn.Module: return request.getfixturevalue(request.param) From 2a63e4ac5c16d7fdba24260c396a73ef2f77996f Mon Sep 17 00:00:00 2001 From: sooahleex Date: Wed, 31 Jul 2024 16:37:48 +0900 Subject: [PATCH 02/16] Update loss using label smoothing --- .../algo/classification/efficientnet_v2.py | 539 ++++++++++++++++++ .../classification/heads/hlabel_cls_head.py | 2 +- 2 files changed, 540 insertions(+), 1 deletion(-) create mode 100644 src/otx/algo/classification/efficientnet_v2.py diff --git a/src/otx/algo/classification/efficientnet_v2.py b/src/otx/algo/classification/efficientnet_v2.py new file mode 100644 index 00000000000..e42af4db7ad --- /dev/null +++ b/src/otx/algo/classification/efficientnet_v2.py @@ -0,0 +1,539 @@ +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""EfficientNetV2 model implementation.""" +from __future__ import annotations + +from copy import deepcopy +from typing import TYPE_CHECKING, Any + +import torch +from torch import Tensor, nn + +from otx.algo.classification.backbones.timm import TimmBackbone +from otx.algo.classification.classifier import ImageClassifier, SemiSLClassifier +from otx.algo.classification.heads import ( + HierarchicalLinearClsHead, + LinearClsHead, + MultiLabelLinearClsHead, + OTXSemiSLLinearClsHead, + HierarchicalCBAMClsHead, +) +from otx.algo.classification.losses.asymmetric_angular_loss_with_ignore import AsymmetricAngularLossWithIgnore +from otx.algo.classification.necks.gap import GlobalAveragePooling +from otx.algo.classification.utils import get_classification_layers +from otx.algo.utils.support_otx_v1 import OTXv1Helper +from otx.core.data.entity.base import OTXBatchLossEntity +from otx.core.data.entity.classification import ( + HlabelClsBatchDataEntity, + HlabelClsBatchPredEntity, + MulticlassClsBatchDataEntity, + MulticlassClsBatchPredEntity, + MultilabelClsBatchDataEntity, + MultilabelClsBatchPredEntity, +) +from otx.core.exporter.base import OTXModelExporter +from otx.core.exporter.native import OTXNativeModelExporter +from otx.core.metrics import MetricInput +from otx.core.metrics.accuracy import HLabelClsMetricCallble, MultiClassClsMetricCallable, MultiLabelClsMetricCallable +from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable +from otx.core.model.classification import ( + OTXHlabelClsModel, + OTXMulticlassClsModel, + OTXMultilabelClsModel, +) +from otx.core.schedulers import LRSchedulerListCallable +from otx.core.types.label import HLabelInfo, LabelInfoTypes + +if TYPE_CHECKING: + from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable + + from otx.core.metrics import MetricCallable + + +class EfficientNetV2ForMulticlassCls(OTXMulticlassClsModel): + """EfficientNetV2 Model for multi-class classification task.""" + + def __init__( + self, + label_info: LabelInfoTypes, + optimizer: OptimizerCallable = DefaultOptimizerCallable, + scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, + metric: MetricCallable = MultiClassClsMetricCallable, + torch_compile: bool = False, + ) -> None: + super().__init__( + label_info=label_info, + optimizer=optimizer, + scheduler=scheduler, + metric=metric, + torch_compile=torch_compile, + ) + + def _create_model(self) -> nn.Module: + # Get classification_layers for class-incr learning + sample_model_dict = self._build_model(num_classes=5).state_dict() + incremental_model_dict = self._build_model(num_classes=6).state_dict() + self.classification_layers = get_classification_layers( + sample_model_dict, + incremental_model_dict, + prefix="model.", + ) + + model = self._build_model(num_classes=self.num_classes) + model.init_weights() + return model + + def _build_model(self, num_classes: int) -> nn.Module: + return ImageClassifier( + backbone=TimmBackbone(backbone="efficientnetv2_s_21k", pretrained=True), + neck=GlobalAveragePooling(dim=2), + head=LinearClsHead( + num_classes=num_classes, + in_channels=1280, + topk=(1, 5) if num_classes >= 5 else (1,), + loss=nn.CrossEntropyLoss(reduction="none"), + ), + ) + + def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: + """Load the previous OTX ckpt according to OTX2.0.""" + return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "multiclass", add_prefix) + + def _customize_inputs(self, inputs: MulticlassClsBatchDataEntity) -> dict[str, Any]: + if self.training: + mode = "loss" + elif self.explain_mode: + mode = "explain" + else: + mode = "predict" + + return { + "images": inputs.stacked_images, + "labels": torch.cat(inputs.labels, dim=0), + "imgs_info": inputs.imgs_info, + "mode": mode, + } + + def _customize_outputs( + self, + outputs: Any, # noqa: ANN401 + inputs: MulticlassClsBatchDataEntity, + ) -> MulticlassClsBatchPredEntity | OTXBatchLossEntity: + if self.training: + return OTXBatchLossEntity(loss=outputs) + + # To list, batch-wise + logits = outputs if isinstance(outputs, torch.Tensor) else outputs["logits"] + scores = torch.unbind(logits, 0) + preds = logits.argmax(-1, keepdim=True).unbind(0) + + return MulticlassClsBatchPredEntity( + batch_size=inputs.batch_size, + images=inputs.images, + imgs_info=inputs.imgs_info, + scores=scores, + labels=preds, + ) + + @property + def _exporter(self) -> OTXModelExporter: + """Creates OTXModelExporter object that can export the model.""" + return OTXNativeModelExporter( + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + resize_mode="standard", + pad_value=0, + swap_rgb=False, + via_onnx=False, + onnx_export_configuration=None, + output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, + ) + + def forward_explain(self, inputs: MulticlassClsBatchDataEntity) -> MulticlassClsBatchPredEntity: + """Model forward explain function.""" + outputs = self.model(images=inputs.stacked_images, mode="explain") + + return MulticlassClsBatchPredEntity( + batch_size=len(outputs["preds"]), + images=inputs.images, + imgs_info=inputs.imgs_info, + labels=outputs["preds"], + scores=outputs["scores"], + saliency_map=outputs["saliency_map"], + feature_vector=outputs["feature_vector"], + ) + + def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]: + """Model forward function used for the model tracing during model exportation.""" + if self.explain_mode: + return self.model(images=image, mode="explain") + + return self.model(images=image, mode="tensor") + + +class EfficientNetV2ForMulticlassClsSemiSL(EfficientNetV2ForMulticlassCls): + """EfficientNetV2 model for multiclass classification with semi-supervised learning. + + This class extends the `EfficientNetV2ForMulticlassCls` class and adds support for semi-supervised learning. + It overrides the `_build_model` and `_customize_inputs` methods to incorporate the semi-supervised learning. + + Args: + EfficientNetV2ForMulticlassCls (class): The base class for EfficientNetV2 multiclass classification. + """ + + def _build_model(self, num_classes: int) -> nn.Module: + return SemiSLClassifier( + backbone=TimmBackbone(backbone="efficientnetv2_s_21k", pretrained=True), + neck=GlobalAveragePooling(dim=2), + head=OTXSemiSLLinearClsHead( + num_classes=num_classes, + in_channels=1280, + loss=nn.CrossEntropyLoss(reduction="none"), + ), + ) + + def _customize_inputs(self, inputs: MulticlassClsBatchDataEntity) -> dict[str, Any]: + """Customizes the input data for the model based on the current mode. + + Args: + inputs (MulticlassClsBatchDataEntity): The input batch of data. + + Returns: + dict[str, Any]: The customized input data. + """ + if self.training: + mode = "loss" + elif self.explain_mode: + mode = "explain" + else: + mode = "predict" + + if isinstance(inputs, dict): + # When used with an unlabeled dataset, it comes in as a dict. + images = {key: inputs[key].images for key in inputs} + labels = {key: torch.cat(inputs[key].labels, dim=0) for key in inputs} + imgs_info = {key: inputs[key].imgs_info for key in inputs} + return { + "images": images, + "labels": labels, + "imgs_info": imgs_info, + "mode": mode, + } + return { + "images": inputs.stacked_images, + "labels": torch.cat(inputs.labels, dim=0), + "imgs_info": inputs.imgs_info, + "mode": mode, + } + + def training_step(self, batch: MulticlassClsBatchDataEntity, batch_idx: int) -> Tensor: + """Performs a single training step on a batch of data. + + Args: + batch (MulticlassClsBatchDataEntity): The input batch of data. + batch_idx (int): The index of the current batch. + + Returns: + Tensor: The computed loss for the training step. + """ + loss = super().training_step(batch, batch_idx) + # Collect metrics related to Semi-SL Training. + self.log( + "train/unlabeled_coef", + self.model.head.unlabeled_coef, + on_step=True, + on_epoch=False, + prog_bar=True, + ) + self.log( + "train/num_pseudo_label", + self.model.head.num_pseudo_label, + on_step=True, + on_epoch=False, + prog_bar=True, + ) + return loss + + +class EfficientNetV2ForMultilabelCls(OTXMultilabelClsModel): + """EfficientNetV2 Model for multi-label classification task.""" + + def __init__( + self, + label_info: LabelInfoTypes, + optimizer: OptimizerCallable = DefaultOptimizerCallable, + scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, + metric: MetricCallable = MultiLabelClsMetricCallable, + torch_compile: bool = False, + ) -> None: + super().__init__( + label_info=label_info, + optimizer=optimizer, + scheduler=scheduler, + metric=metric, + torch_compile=torch_compile, + ) + + def _create_model(self) -> nn.Module: + # Get classification_layers for class-incr learning + sample_model_dict = self._build_model(num_classes=5).state_dict() + incremental_model_dict = self._build_model(num_classes=6).state_dict() + self.classification_layers = get_classification_layers( + sample_model_dict, + incremental_model_dict, + prefix="model.", + ) + + model = self._build_model(num_classes=self.num_classes) + model.init_weights() + return model + + def _build_model(self, num_classes: int) -> nn.Module: + return ImageClassifier( + backbone=TimmBackbone(backbone="efficientnetv2_s_21k", pretrained=True), + neck=GlobalAveragePooling(dim=2), + head=MultiLabelLinearClsHead( + num_classes=num_classes, + in_channels=1280, + normalized=True, + scale=7.0, + loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), + ), + ) + + def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: + """Load the previous OTX ckpt according to OTX2.0.""" + return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "multilabel", add_prefix) + + def _customize_inputs(self, inputs: MultilabelClsBatchDataEntity) -> dict[str, Any]: + if self.training: + mode = "loss" + elif self.explain_mode: + mode = "explain" + else: + mode = "predict" + + return { + "images": inputs.stacked_images, + "labels": torch.stack(inputs.labels), + "imgs_info": inputs.imgs_info, + "mode": mode, + } + + def _customize_outputs( + self, + outputs: Any, # noqa: ANN401 + inputs: MultilabelClsBatchDataEntity, + ) -> MultilabelClsBatchPredEntity | OTXBatchLossEntity: + if self.training: + return OTXBatchLossEntity(loss=outputs) + + # To list, batch-wise + logits = outputs if isinstance(outputs, torch.Tensor) else outputs["logits"] + scores = torch.unbind(logits, 0) + + return MultilabelClsBatchPredEntity( + batch_size=inputs.batch_size, + images=inputs.images, + imgs_info=inputs.imgs_info, + scores=scores, + labels=logits.argmax(-1, keepdim=True).unbind(0), + ) + + @property + def _exporter(self) -> OTXModelExporter: + """Creates OTXModelExporter object that can export the model.""" + return OTXNativeModelExporter( + task_level_export_parameters=self._export_parameters, + input_size=(1, 3, 224, 224), + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + resize_mode="standard", + pad_value=0, + swap_rgb=False, + via_onnx=False, + onnx_export_configuration=None, + output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, + ) + + def forward_explain(self, inputs: MultilabelClsBatchDataEntity) -> MultilabelClsBatchPredEntity: + """Model forward explain function.""" + outputs = self.model(images=inputs.stacked_images, mode="explain") + + return MultilabelClsBatchPredEntity( + batch_size=len(outputs["preds"]), + images=inputs.images, + imgs_info=inputs.imgs_info, + labels=outputs["preds"], + scores=outputs["scores"], + saliency_map=outputs["saliency_map"], + feature_vector=outputs["feature_vector"], + ) + + def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]: + """Model forward function used for the model tracing during model exportation.""" + if self.explain_mode: + return self.model(images=image, mode="explain") + + return self.model(images=image, mode="tensor") + + +class EfficientNetV2ForHLabelCls(OTXHlabelClsModel): + """EfficientNetV2 Model for hierarchical label classification task.""" + + label_info: HLabelInfo + + def __init__( + self, + label_info: HLabelInfo, + optimizer: OptimizerCallable = DefaultOptimizerCallable, + scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, + metric: MetricCallable = HLabelClsMetricCallble, + torch_compile: bool = False, + ) -> None: + super().__init__( + label_info=label_info, + optimizer=optimizer, + scheduler=scheduler, + metric=metric, + torch_compile=torch_compile, + ) + + def _create_model(self) -> nn.Module: + # Get classification_layers for class-incr learning + sample_config = deepcopy(self.label_info.as_head_config_dict()) + sample_config["num_classes"] = 5 + sample_model_dict = self._build_model(head_config=sample_config).state_dict() + sample_config["num_classes"] = 6 + incremental_model_dict = self._build_model(head_config=sample_config).state_dict() + self.classification_layers = get_classification_layers( + sample_model_dict, + incremental_model_dict, + prefix="model.", + ) + + model = self._build_model(head_config=self.label_info.as_head_config_dict()) + model.init_weights() + return model + + def label_smoothing_loss(self, output, target, num_classes, smoothing=0.1): + confidence = 1.0 - smoothing + smoothed_labels = torch.full(size=(target.size(0), num_classes), fill_value=smoothing / (num_classes - 1)).to(target.device) + smoothed_labels.scatter_(1, target.unsqueeze(1), confidence) + loss = -torch.sum(smoothed_labels * nn.LogSoftmax(dim=1)(output), dim=1) + return loss.mean() + + def _build_model(self, head_config: dict) -> nn.Module: + return ImageClassifier( + backbone=TimmBackbone(backbone="efficientnetv2_s_21k", pretrained=True), + neck=GlobalAveragePooling(dim=2), + head=HierarchicalCBAMClsHead( + in_channels=1280, + multiclass_loss=self.label_smoothing_loss, + multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), + **head_config, + ), + ) + + def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: + """Load the previous OTX ckpt according to OTX2.0.""" + return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "hlabel", add_prefix) + + def _customize_inputs(self, inputs: HlabelClsBatchDataEntity) -> dict[str, Any]: + if self.training: + mode = "loss" + elif self.explain_mode: + mode = "explain" + else: + mode = "predict" + + return { + "images": inputs.stacked_images, + "labels": torch.stack(inputs.labels), + "imgs_info": inputs.imgs_info, + "mode": mode, + } + + def _customize_outputs( + self, + outputs: Any, # noqa: ANN401 + inputs: HlabelClsBatchDataEntity, + ) -> HlabelClsBatchPredEntity | OTXBatchLossEntity: + if self.training: + return OTXBatchLossEntity(loss=outputs) + + # To list, batch-wise + if isinstance(outputs, dict): + scores = outputs["scores"] + labels = outputs["labels"] + else: + scores = outputs + labels = outputs.argmax(-1, keepdim=True) + + return HlabelClsBatchPredEntity( + batch_size=inputs.batch_size, + images=inputs.images, + imgs_info=inputs.imgs_info, + scores=scores, + labels=labels, + ) + + def _convert_pred_entity_to_compute_metric( + self, + preds: HlabelClsBatchPredEntity, + inputs: HlabelClsBatchDataEntity, + ) -> MetricInput: + hlabel_info: HLabelInfo = self.label_info # type: ignore[assignment] + + _labels = torch.stack(preds.labels) if isinstance(preds.labels, list) else preds.labels + _scores = torch.stack(preds.scores) if isinstance(preds.scores, list) else preds.scores + if hlabel_info.num_multilabel_classes > 0: + preds_multiclass = _labels[:, : hlabel_info.num_multiclass_heads] + preds_multilabel = _scores[:, hlabel_info.num_multiclass_heads :] + pred_result = torch.cat([preds_multiclass, preds_multilabel], dim=1) + else: + pred_result = _labels + return { + "preds": pred_result, + "target": torch.stack(inputs.labels), + } + + @property + def _exporter(self) -> OTXModelExporter: + """Creates OTXModelExporter object that can export the model.""" + return OTXNativeModelExporter( + task_level_export_parameters=self._export_parameters, + input_size=(1, 3, 224, 224), + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + resize_mode="standard", + pad_value=0, + swap_rgb=False, + via_onnx=False, + onnx_export_configuration=None, + output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, + ) + + def forward_explain(self, inputs: HlabelClsBatchDataEntity) -> HlabelClsBatchPredEntity: + """Model forward explain function.""" + outputs = self.model(images=inputs.stacked_images, mode="explain") + + return HlabelClsBatchPredEntity( + batch_size=len(outputs["preds"]), + images=inputs.images, + imgs_info=inputs.imgs_info, + labels=outputs["preds"], + scores=outputs["scores"], + saliency_map=outputs["saliency_map"], + feature_vector=outputs["feature_vector"], + ) + + def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]: + """Model forward function used for the model tracing during model exportation.""" + if self.explain_mode: + return self.model(images=image, mode="explain") + + return self.model(images=image, mode="tensor") diff --git a/src/otx/algo/classification/heads/hlabel_cls_head.py b/src/otx/algo/classification/heads/hlabel_cls_head.py index 2a4289e9805..99e899b7991 100644 --- a/src/otx/algo/classification/heads/hlabel_cls_head.py +++ b/src/otx/algo/classification/heads/hlabel_cls_head.py @@ -103,7 +103,7 @@ def loss(self, feats: tuple[torch.Tensor], labels: torch.Tensor, **kwargs) -> to head_gt = head_gt[valid_mask] if len(head_gt) > 0: head_logits = head_logits[valid_mask, :] - loss_score += self.multiclass_loss(head_logits, head_gt) + loss_score += self.multiclass_loss(head_logits, head_gt, logit_range[1]-logit_range[0]) num_effective_heads_in_batch += 1 if num_effective_heads_in_batch > 0: From c37bd05958457e83baaf352421cd1e26eae0d923 Mon Sep 17 00:00:00 2001 From: sooahleex Date: Sat, 3 Aug 2024 14:07:28 +0900 Subject: [PATCH 03/16] Update head --- .../classification/heads/hlabel_cls_head.py | 25 ++----------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/src/otx/algo/classification/heads/hlabel_cls_head.py b/src/otx/algo/classification/heads/hlabel_cls_head.py index 99e899b7991..880422455aa 100644 --- a/src/otx/algo/classification/heads/hlabel_cls_head.py +++ b/src/otx/algo/classification/heads/hlabel_cls_head.py @@ -441,38 +441,17 @@ def __init__( init_cfg=init_cfg, **kwargs, ) - self.hid_channels = hid_channels - self.dropout = dropout - - self.activation_callable = activation_callable - self.fc_superclass = nn.Linear(in_channels, num_multiclass_heads) self.attention_fc = nn.Linear(num_multiclass_heads, in_channels) self.cbam = CBAM(in_channels) - - classifier_modules = [ - nn.Linear(in_channels, hid_channels), - nn.BatchNorm1d(hid_channels), - self.activation_callable if isinstance(self.activation_callable, nn.Module) else self.activation_callable(), - ] - - if self.dropout: - classifier_modules.append(nn.Dropout(p=0.2)) - - classifier_modules.append(nn.Linear(hid_channels, num_classes)) - - self.fc_subclass = nn.Sequential(*classifier_modules) + self.fc_subclass = nn.Linear(in_channels, num_single_label_classes) self._init_layers() def _init_layers(self) -> None: """Iniitialize weights of classification head.""" normal_init(self.fc_superclass, mean=0, std=0.01, bias=0) - for module in self.fc_subclass: - if isinstance(module, nn.Linear): - normal_init(module, mean=0, std=0.01, bias=0) - elif isinstance(module, nn.BatchNorm1d): - constant_init(module, 1) + normal_init(self.fc_subclass, mean=0, std=0.01, bias=0) def forward(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor: """The forward process.""" From b817dbc8a1ebefed2ab12a840d10480265ee1289 Mon Sep 17 00:00:00 2001 From: sooahleex Date: Mon, 5 Aug 2024 10:24:40 +0900 Subject: [PATCH 04/16] Use CBAM for efficientnet --- src/otx/algo/classification/efficientnet.py | 19 ++++++++--- .../classification/heads/hlabel_cls_head.py | 32 +++++++++++++++---- 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/src/otx/algo/classification/efficientnet.py b/src/otx/algo/classification/efficientnet.py index ce036ae0aa4..0cbae243303 100644 --- a/src/otx/algo/classification/efficientnet.py +++ b/src/otx/algo/classification/efficientnet.py @@ -19,6 +19,7 @@ LinearClsHead, MultiLabelLinearClsHead, OTXSemiSLLinearClsHead, + HierarchicalCBAMClsHead, ) from otx.algo.classification.losses.asymmetric_angular_loss_with_ignore import AsymmetricAngularLossWithIgnore from otx.algo.classification.necks.gap import GlobalAveragePooling @@ -258,17 +259,25 @@ def _create_model(self) -> nn.Module: model.init_weights() return model + def label_smoothing_loss(self, output, target, num_classes, smoothing=0.1): + confidence = 1.0 - smoothing + smoothed_labels = torch.full(size=(target.size(0), num_classes), fill_value=smoothing / (num_classes - 1)).to(target.device) + smoothed_labels.scatter_(1, target.unsqueeze(1), confidence) + loss = -torch.sum(smoothed_labels * nn.LogSoftmax(dim=1)(output), dim=1) + return loss.mean() + def _build_model(self, head_config: dict) -> nn.Module: if not isinstance(self.label_info, HLabelInfo): raise TypeError(self.label_info) backbone = OTXEfficientNet(version=self.version, pretrained=self.pretrained) return ImageClassifier( - backbone=backbone, - neck=GlobalAveragePooling(dim=2), - head=HierarchicalLinearClsHead( - in_channels=backbone.num_features, - multiclass_loss=nn.CrossEntropyLoss(), + backbone=OTXEfficientNet(version=self.version, pretrained=True), + # neck=GlobalAveragePooling(dim=2), + neck=nn.Identity(), + head=HierarchicalCBAMClsHead( + in_channels=1280, + multiclass_loss=self.label_smoothing_loss, multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), **head_config, ), diff --git a/src/otx/algo/classification/heads/hlabel_cls_head.py b/src/otx/algo/classification/heads/hlabel_cls_head.py index 880422455aa..7e1eed6fc22 100644 --- a/src/otx/algo/classification/heads/hlabel_cls_head.py +++ b/src/otx/algo/classification/heads/hlabel_cls_head.py @@ -103,7 +103,9 @@ def loss(self, feats: tuple[torch.Tensor], labels: torch.Tensor, **kwargs) -> to head_gt = head_gt[valid_mask] if len(head_gt) > 0: head_logits = head_logits[valid_mask, :] - loss_score += self.multiclass_loss(head_logits, head_gt, logit_range[1]-logit_range[0]) + num_classes = logit_range[1] - logit_range[0] + loss_score += self.multiclass_loss(head_logits, head_gt, num_classes) + # loss_score += self.multiclass_loss(head_logits, head_gt) num_effective_heads_in_batch += 1 if num_effective_heads_in_batch > 0: @@ -187,6 +189,15 @@ def _get_predictions( multiclass_pred_scores = torch.cat(multiclass_pred_scores, dim=1) multiclass_pred_labels = torch.cat(multiclass_pred_labels, dim=1) + # multiclass_pred = torch.softmax(cls_scores, dim=1) + # multiclass_pred_score, multiclass_pred_label = torch.max(multiclass_pred, dim=1) + + # multiclass_pred_scores.extend(multiclass_pred_score.view(-1, 1)) + # multiclass_pred_labels.extend(multiclass_pred_label.view(-1, 1)) + + # multiclass_pred_scores = torch.cat(multiclass_pred_scores) + # multiclass_pred_labels = torch.cat(multiclass_pred_labels) + if self.num_multilabel_classes > 0: multilabel_logits = cls_scores[:, self.num_single_label_classes :] @@ -441,13 +452,22 @@ def __init__( init_cfg=init_cfg, **kwargs, ) - self.fc_superclass = nn.Linear(in_channels, num_multiclass_heads) - self.attention_fc = nn.Linear(num_multiclass_heads, in_channels) + step_size = 7 + self.step_size = step_size + self.fc_superclass = nn.Linear(in_channels * step_size * step_size, num_multiclass_heads) + self.attention_fc = nn.Linear(num_multiclass_heads, in_channels * step_size * step_size) self.cbam = CBAM(in_channels) - self.fc_subclass = nn.Linear(in_channels, num_single_label_classes) + self.fc_subclass = nn.Linear(in_channels * step_size * step_size, num_single_label_classes) self._init_layers() + def pre_logits(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor: + """The process before the final classification head.""" + if isinstance(feats, Sequence): + feats = feats[-1] + return feats.view(feats.size(0), self.in_channels*self.step_size*self.step_size) + # return feats[-1] + def _init_layers(self) -> None: """Iniitialize weights of classification head.""" normal_init(self.fc_superclass, mean=0, std=0.01, bias=0) @@ -461,9 +481,9 @@ def forward(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor: attention_weights = torch.sigmoid(self.attention_fc(out_superclass)) attended_features = pre_logits * attention_weights - attended_features = attended_features.view(pre_logits.size(0), self.in_channels, 1, 1) + attended_features = attended_features.view(pre_logits.size(0), self.in_channels, self.step_size, self.step_size) attended_features = self.cbam(attended_features) - attended_features = attended_features.view(pre_logits.size(0), self.in_channels) + attended_features = attended_features.view(pre_logits.size(0), self.in_channels*self.step_size*self.step_size) out_subclass = self.fc_subclass(attended_features) return out_subclass From e731363441e0c88a0b546e49b72bc1f1cfe03bd7 Mon Sep 17 00:00:00 2001 From: sooahleex Date: Tue, 6 Aug 2024 15:36:10 +0900 Subject: [PATCH 05/16] Remove loss --- src/otx/algo/classification/efficientnet.py | 13 +--- .../algo/classification/efficientnet_v2.py | 14 +---- .../classification/heads/hlabel_cls_head.py | 63 ++++++++++--------- src/otx/algo/classification/mobilenet_v3.py | 6 +- 4 files changed, 43 insertions(+), 53 deletions(-) diff --git a/src/otx/algo/classification/efficientnet.py b/src/otx/algo/classification/efficientnet.py index 0cbae243303..e3aa05f8314 100644 --- a/src/otx/algo/classification/efficientnet.py +++ b/src/otx/algo/classification/efficientnet.py @@ -19,7 +19,6 @@ LinearClsHead, MultiLabelLinearClsHead, OTXSemiSLLinearClsHead, - HierarchicalCBAMClsHead, ) from otx.algo.classification.losses.asymmetric_angular_loss_with_ignore import AsymmetricAngularLossWithIgnore from otx.algo.classification.necks.gap import GlobalAveragePooling @@ -259,13 +258,6 @@ def _create_model(self) -> nn.Module: model.init_weights() return model - def label_smoothing_loss(self, output, target, num_classes, smoothing=0.1): - confidence = 1.0 - smoothing - smoothed_labels = torch.full(size=(target.size(0), num_classes), fill_value=smoothing / (num_classes - 1)).to(target.device) - smoothed_labels.scatter_(1, target.unsqueeze(1), confidence) - loss = -torch.sum(smoothed_labels * nn.LogSoftmax(dim=1)(output), dim=1) - return loss.mean() - def _build_model(self, head_config: dict) -> nn.Module: if not isinstance(self.label_info, HLabelInfo): raise TypeError(self.label_info) @@ -273,11 +265,10 @@ def _build_model(self, head_config: dict) -> nn.Module: backbone = OTXEfficientNet(version=self.version, pretrained=self.pretrained) return ImageClassifier( backbone=OTXEfficientNet(version=self.version, pretrained=True), - # neck=GlobalAveragePooling(dim=2), neck=nn.Identity(), - head=HierarchicalCBAMClsHead( + head=HierarchicalLinearClsHead( in_channels=1280, - multiclass_loss=self.label_smoothing_loss, + multiclass_loss=nn.CrossEntropyLoss(), multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), **head_config, ), diff --git a/src/otx/algo/classification/efficientnet_v2.py b/src/otx/algo/classification/efficientnet_v2.py index e42af4db7ad..a8b1bfb2fde 100644 --- a/src/otx/algo/classification/efficientnet_v2.py +++ b/src/otx/algo/classification/efficientnet_v2.py @@ -13,11 +13,10 @@ from otx.algo.classification.backbones.timm import TimmBackbone from otx.algo.classification.classifier import ImageClassifier, SemiSLClassifier from otx.algo.classification.heads import ( - HierarchicalLinearClsHead, + HierarchicalCBAMClsHead, LinearClsHead, MultiLabelLinearClsHead, OTXSemiSLLinearClsHead, - HierarchicalCBAMClsHead, ) from otx.algo.classification.losses.asymmetric_angular_loss_with_ignore import AsymmetricAngularLossWithIgnore from otx.algo.classification.necks.gap import GlobalAveragePooling @@ -419,20 +418,13 @@ def _create_model(self) -> nn.Module: model.init_weights() return model - def label_smoothing_loss(self, output, target, num_classes, smoothing=0.1): - confidence = 1.0 - smoothing - smoothed_labels = torch.full(size=(target.size(0), num_classes), fill_value=smoothing / (num_classes - 1)).to(target.device) - smoothed_labels.scatter_(1, target.unsqueeze(1), confidence) - loss = -torch.sum(smoothed_labels * nn.LogSoftmax(dim=1)(output), dim=1) - return loss.mean() - def _build_model(self, head_config: dict) -> nn.Module: return ImageClassifier( backbone=TimmBackbone(backbone="efficientnetv2_s_21k", pretrained=True), - neck=GlobalAveragePooling(dim=2), + neck=nn.Identity(), head=HierarchicalCBAMClsHead( in_channels=1280, - multiclass_loss=self.label_smoothing_loss, + multiclass_loss=nn.CrossEntropyLoss(), multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), **head_config, ), diff --git a/src/otx/algo/classification/heads/hlabel_cls_head.py b/src/otx/algo/classification/heads/hlabel_cls_head.py index 7e1eed6fc22..842d7fb2e63 100644 --- a/src/otx/algo/classification/heads/hlabel_cls_head.py +++ b/src/otx/algo/classification/heads/hlabel_cls_head.py @@ -103,9 +103,7 @@ def loss(self, feats: tuple[torch.Tensor], labels: torch.Tensor, **kwargs) -> to head_gt = head_gt[valid_mask] if len(head_gt) > 0: head_logits = head_logits[valid_mask, :] - num_classes = logit_range[1] - logit_range[0] - loss_score += self.multiclass_loss(head_logits, head_gt, num_classes) - # loss_score += self.multiclass_loss(head_logits, head_gt) + loss_score += self.multiclass_loss(head_logits, head_gt) num_effective_heads_in_batch += 1 if num_effective_heads_in_batch > 0: @@ -365,38 +363,53 @@ def forward(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor: pre_logits = self.pre_logits(feats) return self.classifier(pre_logits) + class ChannelAttention(nn.Module): - def __init__(self, in_channels, reduction=16): - super(ChannelAttention, self).__init__() + """Channel attention module that uses average and max pooling to enhance important channels.""" + + def __init__(self, in_channels: int, reduction: int = 16): + """Initializes the ChannelAttention module.""" + super().__init__(ChannelAttention, self) self.fc1 = nn.Conv2d(in_channels, in_channels // reduction, kernel_size=1, bias=False) self.fc2 = nn.Conv2d(in_channels // reduction, in_channels, kernel_size=1, bias=False) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Applies channel attention to the input tensor.""" avg_out = self.fc2(torch.relu(self.fc1(torch.mean(x, dim=2, keepdim=True).mean(dim=3, keepdim=True)))) max_out = self.fc2(torch.relu(self.fc1(torch.max(x, dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]))) return torch.sigmoid(avg_out + max_out) + class SpatialAttention(nn.Module): - def __init__(self, kernel_size=7): - super(SpatialAttention, self).__init__() - self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False) + """Spatial attention module that uses average and max pooling to enhance important spatial locations.""" - def forward(self, x): + def __init__(self, kernel_size: int = 7): + """Initializes the SpatialAttention module.""" + super().__init__(SpatialAttention, self) + self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Applies spatial attention to the input tensor.""" avg_out = torch.mean(x, dim=1, keepdim=True) max_out = torch.max(x, dim=1, keepdim=True)[0] x = torch.cat([avg_out, max_out], dim=1) return torch.sigmoid(self.conv(x)) + class CBAM(nn.Module): - def __init__(self, in_channels, reduction=16, kernel_size=7): - super(CBAM, self).__init__() + """CBAM module that applies channel and spatial attention sequentially.""" + + def __init__(self, in_channels: int, reduction: int = 16, kernel_size: int = 7): + """Initializes the CBAM module with channel and spatial attention.""" + super().__init__(CBAM, self) self.channel_attention = ChannelAttention(in_channels, reduction) self.spatial_attention = SpatialAttention(kernel_size) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Applies channel and spatial attention to the input tensor.""" x = x * self.channel_attention(x) - x = x * self.spatial_attention(x) - return x + return x * self.spatial_attention(x) + class HierarchicalCBAMClsHead(HierarchicalClsHead): """Custom classification CBAM head for hierarchical classification task. @@ -414,10 +427,6 @@ class HierarchicalCBAMClsHead(HierarchicalClsHead): multilabel_loss (dict | None): Config of multi-label loss. thr (float | None): Predictions with scores under the thresholds are considered as negative. Defaults to 0.5. - hid_cahnnels (int): Number of channels in the hidden feature map at the classifier. - acivation_Cfg (dict | None): Config of activation layer at the classifier. - dropout (bool): Flag for the enabling the dropout at the classifier. - """ def __init__( @@ -432,9 +441,6 @@ def __init__( multiclass_loss: nn.Module, multilabel_loss: nn.Module | None = None, thr: float = 0.5, - hid_channels: int = 1280, - activation_callable: Callable[[], nn.Module] = nn.ReLU, - dropout: bool = False, init_cfg: dict | None = None, **kwargs, ): @@ -465,8 +471,8 @@ def pre_logits(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor: """The process before the final classification head.""" if isinstance(feats, Sequence): feats = feats[-1] - return feats.view(feats.size(0), self.in_channels*self.step_size*self.step_size) - # return feats[-1] + return feats.view(feats.size(0), self.in_channels * self.step_size * self.step_size) + return feats def _init_layers(self) -> None: """Iniitialize weights of classification head.""" @@ -483,7 +489,8 @@ def forward(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor: attended_features = attended_features.view(pre_logits.size(0), self.in_channels, self.step_size, self.step_size) attended_features = self.cbam(attended_features) - attended_features = attended_features.view(pre_logits.size(0), self.in_channels*self.step_size*self.step_size) - out_subclass = self.fc_subclass(attended_features) - - return out_subclass + attended_features = attended_features.view( + pre_logits.size(0), + self.in_channels * self.step_size * self.step_size, + ) + return self.fc_subclass(attended_features) diff --git a/src/otx/algo/classification/mobilenet_v3.py b/src/otx/algo/classification/mobilenet_v3.py index 756b52721ae..d2696f59ce8 100644 --- a/src/otx/algo/classification/mobilenet_v3.py +++ b/src/otx/algo/classification/mobilenet_v3.py @@ -15,7 +15,7 @@ from otx.algo.classification.backbones import OTXMobileNetV3 from otx.algo.classification.classifier import ImageClassifier, SemiSLClassifier from otx.algo.classification.heads import ( - HierarchicalNonLinearClsHead, + HierarchicalCBAMClsHead, LinearClsHead, MultiLabelNonLinearClsHead, OTXSemiSLLinearClsHead, @@ -325,8 +325,8 @@ def _build_model(self, head_config: dict) -> nn.Module: return ImageClassifier( backbone=OTXMobileNetV3(mode=self.mode), - neck=GlobalAveragePooling(dim=2), - head=HierarchicalNonLinearClsHead( + neck=nn.Identity(), + head=HierarchicalCBAMClsHead( in_channels=960, multiclass_loss=nn.CrossEntropyLoss(), multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), From 8a9aa73722e0e788c6754d773ff15b770b61fff0 Mon Sep 17 00:00:00 2001 From: sooahleex Date: Wed, 7 Aug 2024 17:08:26 +0900 Subject: [PATCH 06/16] Update comments --- .../classification/heads/hlabel_cls_head.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/src/otx/algo/classification/heads/hlabel_cls_head.py b/src/otx/algo/classification/heads/hlabel_cls_head.py index 842d7fb2e63..1174211eed2 100644 --- a/src/otx/algo/classification/heads/hlabel_cls_head.py +++ b/src/otx/algo/classification/heads/hlabel_cls_head.py @@ -187,15 +187,6 @@ def _get_predictions( multiclass_pred_scores = torch.cat(multiclass_pred_scores, dim=1) multiclass_pred_labels = torch.cat(multiclass_pred_labels, dim=1) - # multiclass_pred = torch.softmax(cls_scores, dim=1) - # multiclass_pred_score, multiclass_pred_label = torch.max(multiclass_pred, dim=1) - - # multiclass_pred_scores.extend(multiclass_pred_score.view(-1, 1)) - # multiclass_pred_labels.extend(multiclass_pred_label.view(-1, 1)) - - # multiclass_pred_scores = torch.cat(multiclass_pred_scores) - # multiclass_pred_labels = torch.cat(multiclass_pred_labels) - if self.num_multilabel_classes > 0: multilabel_logits = cls_scores[:, self.num_single_label_classes :] @@ -369,7 +360,7 @@ class ChannelAttention(nn.Module): def __init__(self, in_channels: int, reduction: int = 16): """Initializes the ChannelAttention module.""" - super().__init__(ChannelAttention, self) + super().__init__() self.fc1 = nn.Conv2d(in_channels, in_channels // reduction, kernel_size=1, bias=False) self.fc2 = nn.Conv2d(in_channels // reduction, in_channels, kernel_size=1, bias=False) @@ -385,7 +376,7 @@ class SpatialAttention(nn.Module): def __init__(self, kernel_size: int = 7): """Initializes the SpatialAttention module.""" - super().__init__(SpatialAttention, self) + super().__init__() self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -401,7 +392,7 @@ class CBAM(nn.Module): def __init__(self, in_channels: int, reduction: int = 16, kernel_size: int = 7): """Initializes the CBAM module with channel and spatial attention.""" - super().__init__(CBAM, self) + super().__init__() self.channel_attention = ChannelAttention(in_channels, reduction) self.spatial_attention = SpatialAttention(kernel_size) @@ -442,6 +433,7 @@ def __init__( multilabel_loss: nn.Module | None = None, thr: float = 0.5, init_cfg: dict | None = None, + step_size: int = 7, **kwargs, ): super().__init__( @@ -458,7 +450,6 @@ def __init__( init_cfg=init_cfg, **kwargs, ) - step_size = 7 self.step_size = step_size self.fc_superclass = nn.Linear(in_channels * step_size * step_size, num_multiclass_heads) self.attention_fc = nn.Linear(num_multiclass_heads, in_channels * step_size * step_size) From 552e782dba638ca6f653381912b4e180be235c82 Mon Sep 17 00:00:00 2001 From: sooahleex Date: Wed, 7 Aug 2024 17:10:46 +0900 Subject: [PATCH 07/16] Update update h-label format --- src/otx/core/data/dataset/classification.py | 22 +++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/otx/core/data/dataset/classification.py b/src/otx/core/data/dataset/classification.py index 1d0b82524cb..57170da967b 100644 --- a/src/otx/core/data/dataset/classification.py +++ b/src/otx/core/data/dataset/classification.py @@ -145,7 +145,6 @@ def _get_label_group_idx(label_name: str) -> int: def _find_ancestor_recursively(label_name: str, ancestors: list) -> list[str]: _, dm_label_category = self.dm_categories.find(label_name) - # parent_name = dm_label_category.parent parent_name = dm_label_category.parent if dm_label_category else "" if parent_name != "": @@ -201,21 +200,22 @@ def _convert_label_to_hlabel_format(self, label_anns: list[Label], ignored_label """Convert format of the label to the h-label. It converts the label format to h-label format. + Total length of result is sum of number of hierarchy and number of multilabel classes. i.e. Let's assume that we used the same dataset with example of the definition of HLabelData - and the original labels are ["Rigid", "Panda", "Lion"]. + and the original labels are ["Rigid", "Triangle", "Lion"]. - Then, h-label format will be [1, -1, 0, 1, 1]. + Then, h-label format will be [0, 1, 1, 0]. The first N-th indices represent the label index of multiclass heads (N=num_multiclass_heads), others represent the multilabel labels. - [Multiclass Heads: [1, -1]] - 0-th index = 1 -> ["Non-Rigid"(X), "Rigid"(O)] <- First multiclass head - 1-st index = -1 -> ["Rectangle"(X), "Triangle"(X)] <- Second multiclass head + [Multiclass Heads] + 0-th index = 0 -> ["Rigid"(O), "Non-Rigid"(X)] <- First multiclass head + 1-st index = 1 -> ["Rectangle"(O), "Triangle"(X), "Circle"(X)] <- Second multiclass head - [Multilabel Head: [0, 1, 1]] - 2, 3, 4 indices = [0, 1, 1] -> ["Circle"(X), "Lion"(O), "Panda"(O)] + [Multilabel Head] + 2, 3 indices = [1, 0] -> ["Lion"(O), "Panda"(X)] """ if not isinstance(self.label_info, HLabelInfo): msg = f"The type of label_info should be HLabelInfo, got {type(self.label_info)}." @@ -230,10 +230,16 @@ def _convert_label_to_hlabel_format(self, label_anns: list[Label], ignored_label for ann in label_anns: ann_name = self.dm_categories.items[ann.label].name + ann_parent = self.dm_categories.items[ann.label].parent group_idx, in_group_idx = self.label_info.class_to_group_idx[ann_name] + (parent_group_idx, parent_in_group_idx) = ( + self.label_info.class_to_group_idx[ann_parent] if ann_parent else (None, None) + ) if group_idx < num_multiclass_heads: class_indices[group_idx] = in_group_idx + if parent_group_idx is not None and parent_in_group_idx is not None: + class_indices[parent_group_idx] = parent_in_group_idx elif not ignored_labels or ann.label not in ignored_labels: class_indices[num_multiclass_heads + in_group_idx] = 1 else: From 9e5d653237ea3fa1549a3ed10e9164a3edae1b82 Mon Sep 17 00:00:00 2001 From: sooahleex Date: Sat, 10 Aug 2024 10:49:28 +0900 Subject: [PATCH 08/16] Add hlabel info for h-label head test --- src/otx/algo/classification/efficientnet.py | 2 +- src/otx/algo/classification/heads/__init__.py | 2 +- tests/unit/algo/classification/conftest.py | 75 +++++++++++++++++++ .../heads/test_hlabel_cls_head.py | 15 ++-- 4 files changed, 87 insertions(+), 7 deletions(-) diff --git a/src/otx/algo/classification/efficientnet.py b/src/otx/algo/classification/efficientnet.py index e3aa05f8314..9ea1dde5d3e 100644 --- a/src/otx/algo/classification/efficientnet.py +++ b/src/otx/algo/classification/efficientnet.py @@ -264,7 +264,7 @@ def _build_model(self, head_config: dict) -> nn.Module: backbone = OTXEfficientNet(version=self.version, pretrained=self.pretrained) return ImageClassifier( - backbone=OTXEfficientNet(version=self.version, pretrained=True), + backbone=backbone, neck=nn.Identity(), head=HierarchicalLinearClsHead( in_channels=1280, diff --git a/src/otx/algo/classification/heads/__init__.py b/src/otx/algo/classification/heads/__init__.py index 435688be314..a920d6782bb 100644 --- a/src/otx/algo/classification/heads/__init__.py +++ b/src/otx/algo/classification/heads/__init__.py @@ -3,7 +3,7 @@ # """Head modules for OTX custom model.""" -from .hlabel_cls_head import HierarchicalLinearClsHead, HierarchicalNonLinearClsHead, HierarchicalCBAMClsHead +from .hlabel_cls_head import HierarchicalCBAMClsHead, HierarchicalLinearClsHead, HierarchicalNonLinearClsHead from .linear_head import LinearClsHead from .multilabel_cls_head import MultiLabelLinearClsHead, MultiLabelNonLinearClsHead from .semi_sl_head import OTXSemiSLLinearClsHead, OTXSemiSLVisionTransformerClsHead diff --git a/tests/unit/algo/classification/conftest.py b/tests/unit/algo/classification/conftest.py index 825ce7cabb1..24fb086a14f 100644 --- a/tests/unit/algo/classification/conftest.py +++ b/tests/unit/algo/classification/conftest.py @@ -132,6 +132,81 @@ def fxt_hlabel_multilabel_info() -> HLabelInfo: ) +@pytest.fixture() +def fxt_hlabel_cifar() -> HLabelInfo: + return HLabelInfo( + label_names=[ + "beaver", + "dolphin", + "otter", + "seal", + "whale", + "aquarium_fish", + "flatfish", + "ray", + "shark", + "trout", + "aquatic_mammals", + "fish", + ], + label_groups=[ + ["beaver", "dolphin", "otter", "seal", "whale"], + ["aquarium_fish", "flatfish", "ray", "shark", "trout"], + ["aquatic_mammals", "fish"], + ], + num_multiclass_heads=3, + num_multilabel_classes=0, + head_idx_to_logits_range={"0": (0, 2), "1": (2, 7), "2": (7, 12)}, + num_single_label_classes=10, + empty_multiclass_head_indices=[], + class_to_group_idx={ + "beaver": (0, 0), + "dolphin": (0, 1), + "otter": (0, 2), + "seal": (0, 3), + "whale": (0, 4), + "aquarium_fish": (1, 0), + "flatfish": (1, 1), + "ray": (1, 2), + "shark": (1, 3), + "trout": (1, 4), + "aquatic_mammals": (2, 0), + "fish": (2, 1), + }, + all_groups=[ + ["beaver", "dolphin", "otter", "seal", "whale"], + ["aquarium_fish", "flatfish", "ray", "shark", "trout"], + ["aquatic_mammals", "fish"], + ], + label_to_idx={ + "aquarium_fish": 0, + "beaver": 1, + "dolphin": 2, + "flatfish": 3, + "otter": 4, + "ray": 5, + "seal": 6, + "shark": 7, + "trout": 8, + "whale": 9, + "aquatic_mammals": 10, + "fish": 11, + }, + label_tree_edges=[ + ["aquarium_fish", "fish"], + ["beaver", "aquatic_mammals"], + ["dolphin", "aquatic_mammals"], + ["otter", "aquatic_mammals"], + ["seal", "aquatic_mammals"], + ["whale", "aquatic_mammals"], + ["flatfish", "aquarium_fish"], + ["ray", "aquarium_fish"], + ["shark", "aquarium_fish"], + ["trout", "aquarium_fish"], + ], + ) + + @pytest.fixture() def fxt_multiclass_cls_batch_data_entity() -> MulticlassClsBatchDataEntity: batch_size = 2 diff --git a/tests/unit/algo/classification/heads/test_hlabel_cls_head.py b/tests/unit/algo/classification/heads/test_hlabel_cls_head.py index 4511ac7fc49..9edc1f53d30 100644 --- a/tests/unit/algo/classification/heads/test_hlabel_cls_head.py +++ b/tests/unit/algo/classification/heads/test_hlabel_cls_head.py @@ -8,7 +8,11 @@ import pytest import torch -from otx.algo.classification.heads import HierarchicalLinearClsHead, HierarchicalNonLinearClsHead, HierarchicalCBAMClsHead +from otx.algo.classification.heads import ( + HierarchicalCBAMClsHead, + HierarchicalLinearClsHead, + HierarchicalNonLinearClsHead, +) from otx.algo.classification.losses import AsymmetricAngularLossWithIgnore from otx.core.data.entity.base import ImageInfo from torch import nn @@ -49,9 +53,9 @@ def fxt_data_sample_with_ignored_labels() -> dict: class TestHierarchicalLinearClsHead: @pytest.fixture() - def fxt_head_attrs(self, fxt_hlabel_multilabel_info) -> dict[str, Any]: + def fxt_head_attrs(self, fxt_hlabel_cifar) -> dict[str, Any]: return { - **fxt_hlabel_multilabel_info.as_head_config_dict(), + **fxt_hlabel_cifar.as_head_config_dict(), "in_channels": 24, "multiclass_loss": CrossEntropyLoss(), "multilabel_loss": AsymmetricAngularLossWithIgnore(), @@ -67,6 +71,7 @@ def fxt_hlabel_non_linear_head(self, fxt_head_attrs) -> nn.Module: @pytest.fixture() def fxt_hlabel_cbam_head(self, fxt_head_attrs) -> nn.Module: + fxt_head_attrs["step_size"] = 1 return HierarchicalCBAMClsHead(**fxt_head_attrs) @pytest.fixture(params=["fxt_hlabel_linear_head", "fxt_hlabel_non_linear_head", "fxt_hlabel_cbam_head"]) @@ -97,6 +102,6 @@ def test_predict( result = fxt_hlabel_head.predict(dummy_input, **fxt_data_sample) assert isinstance(result, dict) assert "scores" in result - assert result["scores"].shape == (2, 6) + assert result["scores"].shape == (2, 3) assert "labels" in result - assert result["labels"].shape == (2, 6) + assert result["labels"].shape == (2, 3) From 5df11af6f7a95bad2efc29570a29b176d9b38479 Mon Sep 17 00:00:00 2001 From: sooahleex Date: Sat, 10 Aug 2024 13:05:48 +0900 Subject: [PATCH 09/16] Update head for efficientnet --- src/otx/algo/classification/efficientnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/otx/algo/classification/efficientnet.py b/src/otx/algo/classification/efficientnet.py index 9ea1dde5d3e..35245a550f7 100644 --- a/src/otx/algo/classification/efficientnet.py +++ b/src/otx/algo/classification/efficientnet.py @@ -15,7 +15,7 @@ from otx.algo.classification.classifier.base_classifier import ImageClassifier from otx.algo.classification.classifier.semi_sl_classifier import SemiSLClassifier from otx.algo.classification.heads import ( - HierarchicalLinearClsHead, + HierarchicalCBAMClsHead, LinearClsHead, MultiLabelLinearClsHead, OTXSemiSLLinearClsHead, @@ -266,7 +266,7 @@ def _build_model(self, head_config: dict) -> nn.Module: return ImageClassifier( backbone=backbone, neck=nn.Identity(), - head=HierarchicalLinearClsHead( + head=HierarchicalCBAMClsHead( in_channels=1280, multiclass_loss=nn.CrossEntropyLoss(), multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), From 852c3ec0d5cabbba6339ac303cf9d49d64379984 Mon Sep 17 00:00:00 2001 From: sooahleex Date: Sat, 10 Aug 2024 14:43:00 +0900 Subject: [PATCH 10/16] Update dataset for tests of h-label cls --- src/otx/algo/classification/classifier/base_classifier.py | 3 ++- src/otx/algo/classification/efficientnet.py | 1 + src/otx/algo/classification/efficientnet_v2.py | 1 + src/otx/algo/classification/heads/hlabel_cls_head.py | 3 +-- src/otx/algo/classification/mobilenet_v3.py | 1 + src/otx/algo/classification/torchvision_model.py | 4 ++-- src/otx/algo/classification/vit.py | 5 +++-- tests/unit/algo/classification/conftest.py | 4 ++-- tests/unit/algo/classification/test_deit_tiny.py | 2 +- tests/unit/algo/classification/test_efficientnet.py | 4 ++-- tests/unit/algo/classification/test_mobilenet_v3.py | 4 ++-- 11 files changed, 18 insertions(+), 14 deletions(-) diff --git a/src/otx/algo/classification/classifier/base_classifier.py b/src/otx/algo/classification/classifier/base_classifier.py index 174dee052a0..3c5126824b2 100644 --- a/src/otx/algo/classification/classifier/base_classifier.py +++ b/src/otx/algo/classification/classifier/base_classifier.py @@ -61,6 +61,7 @@ def __init__( neck: nn.Module | None, head: nn.Module, pretrained: str | None = None, + optimize_gap: bool = True, mean: list[float] | None = None, std: list[float] | None = None, to_rgb: bool = False, @@ -81,7 +82,7 @@ def __init__( self.explainer = ReciproCAM( self._head_forward_fn, num_classes=head.num_classes, - optimize_gap=True, + optimize_gap=optimize_gap, ) def forward( diff --git a/src/otx/algo/classification/efficientnet.py b/src/otx/algo/classification/efficientnet.py index 35245a550f7..d8efb9faefe 100644 --- a/src/otx/algo/classification/efficientnet.py +++ b/src/otx/algo/classification/efficientnet.py @@ -272,6 +272,7 @@ def _build_model(self, head_config: dict) -> nn.Module: multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), **head_config, ), + optimize_gap=False, ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: diff --git a/src/otx/algo/classification/efficientnet_v2.py b/src/otx/algo/classification/efficientnet_v2.py index a8b1bfb2fde..78c369d93f0 100644 --- a/src/otx/algo/classification/efficientnet_v2.py +++ b/src/otx/algo/classification/efficientnet_v2.py @@ -428,6 +428,7 @@ def _build_model(self, head_config: dict) -> nn.Module: multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), **head_config, ), + optimize_gap=False, ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: diff --git a/src/otx/algo/classification/heads/hlabel_cls_head.py b/src/otx/algo/classification/heads/hlabel_cls_head.py index 1174211eed2..07c0c97df05 100644 --- a/src/otx/algo/classification/heads/hlabel_cls_head.py +++ b/src/otx/algo/classification/heads/hlabel_cls_head.py @@ -462,8 +462,7 @@ def pre_logits(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor: """The process before the final classification head.""" if isinstance(feats, Sequence): feats = feats[-1] - return feats.view(feats.size(0), self.in_channels * self.step_size * self.step_size) - return feats + return feats.view(feats.size(0), self.in_channels * self.step_size * self.step_size) def _init_layers(self) -> None: """Iniitialize weights of classification head.""" diff --git a/src/otx/algo/classification/mobilenet_v3.py b/src/otx/algo/classification/mobilenet_v3.py index d2696f59ce8..5697427f4b7 100644 --- a/src/otx/algo/classification/mobilenet_v3.py +++ b/src/otx/algo/classification/mobilenet_v3.py @@ -332,6 +332,7 @@ def _build_model(self, head_config: dict) -> nn.Module: multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), **head_config, ), + optimize_gap=False, ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: diff --git a/src/otx/algo/classification/torchvision_model.py b/src/otx/algo/classification/torchvision_model.py index 7a81b536965..4702ab63c1a 100644 --- a/src/otx/algo/classification/torchvision_model.py +++ b/src/otx/algo/classification/torchvision_model.py @@ -11,7 +11,7 @@ from torch import Tensor, nn from torchvision.models import get_model, get_model_weights -from otx.algo.classification.heads import HierarchicalLinearClsHead, MultiLabelLinearClsHead, OTXSemiSLLinearClsHead +from otx.algo.classification.heads import HierarchicalCBAMClsHead, MultiLabelLinearClsHead, OTXSemiSLLinearClsHead from otx.algo.classification.losses import AsymmetricAngularLossWithIgnore from otx.algo.explain.explain_algo import ReciproCAM, feature_vector_fn from otx.core.data.entity.base import OTXBatchLossEntity @@ -210,7 +210,7 @@ def _get_head(self, net: nn.Module) -> nn.Module: ) if self.task == OTXTaskType.H_LABEL_CLS: self.neck = nn.Sequential(*layers) if layers else None - return HierarchicalLinearClsHead( + return HierarchicalCBAMClsHead( in_channels=feature_channel, multiclass_loss=nn.CrossEntropyLoss(), multilabel_loss=self.loss_module, diff --git a/src/otx/algo/classification/vit.py b/src/otx/algo/classification/vit.py index f2ccd09b8d9..a6638249f24 100644 --- a/src/otx/algo/classification/vit.py +++ b/src/otx/algo/classification/vit.py @@ -18,7 +18,7 @@ from otx.algo.classification.backbones.vision_transformer import VIT_ARCH_TYPE, VisionTransformer from otx.algo.classification.classifier import ImageClassifier, SemiSLClassifier from otx.algo.classification.heads import ( - HierarchicalLinearClsHead, + HierarchicalCBAMClsHead, MultiLabelLinearClsHead, OTXSemiSLVisionTransformerClsHead, VisionTransformerClsHead, @@ -494,10 +494,11 @@ def _build_model(self, head_config: dict) -> nn.Module: return ImageClassifier( backbone=vit_backbone, neck=None, - head=HierarchicalLinearClsHead( + head=HierarchicalCBAMClsHead( in_channels=vit_backbone.embed_dim, multiclass_loss=nn.CrossEntropyLoss(), multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), + step_size=1, **head_config, ), init_cfg=init_cfg, diff --git a/tests/unit/algo/classification/conftest.py b/tests/unit/algo/classification/conftest.py index 24fb086a14f..945c3d0bc4c 100644 --- a/tests/unit/algo/classification/conftest.py +++ b/tests/unit/algo/classification/conftest.py @@ -156,8 +156,8 @@ def fxt_hlabel_cifar() -> HLabelInfo: ], num_multiclass_heads=3, num_multilabel_classes=0, - head_idx_to_logits_range={"0": (0, 2), "1": (2, 7), "2": (7, 12)}, - num_single_label_classes=10, + head_idx_to_logits_range={"0": (0, 5), "1": (5, 10), "2": (10, 12)}, + num_single_label_classes=12, empty_multiclass_head_indices=[], class_to_group_idx={ "beaver": (0, 0), diff --git a/tests/unit/algo/classification/test_deit_tiny.py b/tests/unit/algo/classification/test_deit_tiny.py index 420aa50ea35..908c91bb60e 100644 --- a/tests/unit/algo/classification/test_deit_tiny.py +++ b/tests/unit/algo/classification/test_deit_tiny.py @@ -18,7 +18,7 @@ class TestDeitTiny: params=[ (VisionTransformerForMulticlassCls, "fxt_multiclass_cls_batch_data_entity", "fxt_multiclass_labelinfo"), (VisionTransformerForMultilabelCls, "fxt_multilabel_cls_batch_data_entity", "fxt_multilabel_labelinfo"), - (VisionTransformerForHLabelCls, "fxt_hlabel_cls_batch_data_entity", "fxt_hlabel_data"), + (VisionTransformerForHLabelCls, "fxt_hlabel_cls_batch_data_entity", "fxt_hlabel_cifar"), ], ids=["multiclass", "multilabel", "hlabel"], ) diff --git a/tests/unit/algo/classification/test_efficientnet.py b/tests/unit/algo/classification/test_efficientnet.py index 49d16527f7a..b2f2edc688a 100644 --- a/tests/unit/algo/classification/test_efficientnet.py +++ b/tests/unit/algo/classification/test_efficientnet.py @@ -94,10 +94,10 @@ def test_predict_step(self, fxt_multi_label_cls_model, fxt_multilabel_cls_batch_ @pytest.fixture() -def fxt_h_label_cls_model(fxt_hlabel_data): +def fxt_h_label_cls_model(fxt_hlabel_cifar): return EfficientNetForHLabelCls( version="b0", - label_info=fxt_hlabel_data, + label_info=fxt_hlabel_cifar, ) diff --git a/tests/unit/algo/classification/test_mobilenet_v3.py b/tests/unit/algo/classification/test_mobilenet_v3.py index 60981098e1c..62ebcb6ed03 100644 --- a/tests/unit/algo/classification/test_mobilenet_v3.py +++ b/tests/unit/algo/classification/test_mobilenet_v3.py @@ -94,10 +94,10 @@ def test_predict_step(self, fxt_multi_label_cls_model, fxt_multilabel_cls_batch_ @pytest.fixture() -def fxt_h_label_cls_model(fxt_hlabel_data): +def fxt_h_label_cls_model(fxt_hlabel_cifar): return MobileNetV3ForHLabelCls( mode="large", - label_info=fxt_hlabel_data, + label_info=fxt_hlabel_cifar, ) From 2466e4e59280919cefee79aef533a03ea5186d26 Mon Sep 17 00:00:00 2001 From: sooahleex Date: Sat, 10 Aug 2024 15:13:37 +0900 Subject: [PATCH 11/16] Update not use optimize_gap --- src/otx/algo/classification/classifier/base_classifier.py | 2 -- src/otx/algo/classification/efficientnet.py | 1 - src/otx/algo/classification/efficientnet_v2.py | 1 - src/otx/algo/classification/mobilenet_v3.py | 1 - src/otx/algo/classification/torchvision_model.py | 3 ++- 5 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/otx/algo/classification/classifier/base_classifier.py b/src/otx/algo/classification/classifier/base_classifier.py index 3c5126824b2..062bbca450b 100644 --- a/src/otx/algo/classification/classifier/base_classifier.py +++ b/src/otx/algo/classification/classifier/base_classifier.py @@ -61,7 +61,6 @@ def __init__( neck: nn.Module | None, head: nn.Module, pretrained: str | None = None, - optimize_gap: bool = True, mean: list[float] | None = None, std: list[float] | None = None, to_rgb: bool = False, @@ -82,7 +81,6 @@ def __init__( self.explainer = ReciproCAM( self._head_forward_fn, num_classes=head.num_classes, - optimize_gap=optimize_gap, ) def forward( diff --git a/src/otx/algo/classification/efficientnet.py b/src/otx/algo/classification/efficientnet.py index d8efb9faefe..35245a550f7 100644 --- a/src/otx/algo/classification/efficientnet.py +++ b/src/otx/algo/classification/efficientnet.py @@ -272,7 +272,6 @@ def _build_model(self, head_config: dict) -> nn.Module: multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), **head_config, ), - optimize_gap=False, ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: diff --git a/src/otx/algo/classification/efficientnet_v2.py b/src/otx/algo/classification/efficientnet_v2.py index 78c369d93f0..a8b1bfb2fde 100644 --- a/src/otx/algo/classification/efficientnet_v2.py +++ b/src/otx/algo/classification/efficientnet_v2.py @@ -428,7 +428,6 @@ def _build_model(self, head_config: dict) -> nn.Module: multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), **head_config, ), - optimize_gap=False, ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: diff --git a/src/otx/algo/classification/mobilenet_v3.py b/src/otx/algo/classification/mobilenet_v3.py index 5697427f4b7..d2696f59ce8 100644 --- a/src/otx/algo/classification/mobilenet_v3.py +++ b/src/otx/algo/classification/mobilenet_v3.py @@ -332,7 +332,6 @@ def _build_model(self, head_config: dict) -> nn.Module: multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), **head_config, ), - optimize_gap=False, ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: diff --git a/src/otx/algo/classification/torchvision_model.py b/src/otx/algo/classification/torchvision_model.py index 4702ab63c1a..b9e3c2a973e 100644 --- a/src/otx/algo/classification/torchvision_model.py +++ b/src/otx/algo/classification/torchvision_model.py @@ -209,12 +209,13 @@ def _get_head(self, net: nn.Module) -> nn.Module: loss=self.loss_module, ) if self.task == OTXTaskType.H_LABEL_CLS: - self.neck = nn.Sequential(*layers) if layers else None + self.neck = nn.Sequential(*layers, nn.Identity()) if layers else None return HierarchicalCBAMClsHead( in_channels=feature_channel, multiclass_loss=nn.CrossEntropyLoss(), multilabel_loss=self.loss_module, **self.head_config, + step_size=1, ) msg = f"Task type {self.task} is not supported." From 245b952af7f082d9450a547507eb0b1088a341dd Mon Sep 17 00:00:00 2001 From: sooahleex Date: Sat, 10 Aug 2024 16:05:33 +0900 Subject: [PATCH 12/16] Update CHANGELOG --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 63b53d13b83..11ed67e7ae3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,8 @@ All notable changes to this project will be documented in this file. (https://github.com/openvinotoolkit/training_extensions/pull/3762) - Add RTMPose for Keypoint Detection Task (https://github.com/openvinotoolkit/training_extensions/pull/3781) +- Update head and h-label format for hierarchical label classification + (https://github.com/openvinotoolkit/training_extensions/pull/3810) ### Enhancements From 0d6dcfca0da469a698b9d7577caea626390d0818 Mon Sep 17 00:00:00 2001 From: sooahleex Date: Sat, 10 Aug 2024 16:37:18 +0900 Subject: [PATCH 13/16] Use optimize_gap for classifier of h-label models --- .../classifier/base_classifier.py | 2 + src/otx/algo/classification/efficientnet.py | 3 +- .../algo/classification/efficientnet_v2.py | 531 ------------------ src/otx/algo/classification/mobilenet_v3.py | 1 + src/otx/algo/classification/timm_model.py | 7 +- .../algo/classification/test_timm_model.py | 4 +- 6 files changed, 11 insertions(+), 537 deletions(-) delete mode 100644 src/otx/algo/classification/efficientnet_v2.py diff --git a/src/otx/algo/classification/classifier/base_classifier.py b/src/otx/algo/classification/classifier/base_classifier.py index 062bbca450b..3c5126824b2 100644 --- a/src/otx/algo/classification/classifier/base_classifier.py +++ b/src/otx/algo/classification/classifier/base_classifier.py @@ -61,6 +61,7 @@ def __init__( neck: nn.Module | None, head: nn.Module, pretrained: str | None = None, + optimize_gap: bool = True, mean: list[float] | None = None, std: list[float] | None = None, to_rgb: bool = False, @@ -81,6 +82,7 @@ def __init__( self.explainer = ReciproCAM( self._head_forward_fn, num_classes=head.num_classes, + optimize_gap=optimize_gap, ) def forward( diff --git a/src/otx/algo/classification/efficientnet.py b/src/otx/algo/classification/efficientnet.py index 35245a550f7..46ac9597f20 100644 --- a/src/otx/algo/classification/efficientnet.py +++ b/src/otx/algo/classification/efficientnet.py @@ -267,11 +267,12 @@ def _build_model(self, head_config: dict) -> nn.Module: backbone=backbone, neck=nn.Identity(), head=HierarchicalCBAMClsHead( - in_channels=1280, + in_channels=backbone.num_features, multiclass_loss=nn.CrossEntropyLoss(), multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), **head_config, ), + optimize_gap=False, ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: diff --git a/src/otx/algo/classification/efficientnet_v2.py b/src/otx/algo/classification/efficientnet_v2.py deleted file mode 100644 index a8b1bfb2fde..00000000000 --- a/src/otx/algo/classification/efficientnet_v2.py +++ /dev/null @@ -1,531 +0,0 @@ -# Copyright (C) 2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# -"""EfficientNetV2 model implementation.""" -from __future__ import annotations - -from copy import deepcopy -from typing import TYPE_CHECKING, Any - -import torch -from torch import Tensor, nn - -from otx.algo.classification.backbones.timm import TimmBackbone -from otx.algo.classification.classifier import ImageClassifier, SemiSLClassifier -from otx.algo.classification.heads import ( - HierarchicalCBAMClsHead, - LinearClsHead, - MultiLabelLinearClsHead, - OTXSemiSLLinearClsHead, -) -from otx.algo.classification.losses.asymmetric_angular_loss_with_ignore import AsymmetricAngularLossWithIgnore -from otx.algo.classification.necks.gap import GlobalAveragePooling -from otx.algo.classification.utils import get_classification_layers -from otx.algo.utils.support_otx_v1 import OTXv1Helper -from otx.core.data.entity.base import OTXBatchLossEntity -from otx.core.data.entity.classification import ( - HlabelClsBatchDataEntity, - HlabelClsBatchPredEntity, - MulticlassClsBatchDataEntity, - MulticlassClsBatchPredEntity, - MultilabelClsBatchDataEntity, - MultilabelClsBatchPredEntity, -) -from otx.core.exporter.base import OTXModelExporter -from otx.core.exporter.native import OTXNativeModelExporter -from otx.core.metrics import MetricInput -from otx.core.metrics.accuracy import HLabelClsMetricCallble, MultiClassClsMetricCallable, MultiLabelClsMetricCallable -from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable -from otx.core.model.classification import ( - OTXHlabelClsModel, - OTXMulticlassClsModel, - OTXMultilabelClsModel, -) -from otx.core.schedulers import LRSchedulerListCallable -from otx.core.types.label import HLabelInfo, LabelInfoTypes - -if TYPE_CHECKING: - from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable - - from otx.core.metrics import MetricCallable - - -class EfficientNetV2ForMulticlassCls(OTXMulticlassClsModel): - """EfficientNetV2 Model for multi-class classification task.""" - - def __init__( - self, - label_info: LabelInfoTypes, - optimizer: OptimizerCallable = DefaultOptimizerCallable, - scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, - metric: MetricCallable = MultiClassClsMetricCallable, - torch_compile: bool = False, - ) -> None: - super().__init__( - label_info=label_info, - optimizer=optimizer, - scheduler=scheduler, - metric=metric, - torch_compile=torch_compile, - ) - - def _create_model(self) -> nn.Module: - # Get classification_layers for class-incr learning - sample_model_dict = self._build_model(num_classes=5).state_dict() - incremental_model_dict = self._build_model(num_classes=6).state_dict() - self.classification_layers = get_classification_layers( - sample_model_dict, - incremental_model_dict, - prefix="model.", - ) - - model = self._build_model(num_classes=self.num_classes) - model.init_weights() - return model - - def _build_model(self, num_classes: int) -> nn.Module: - return ImageClassifier( - backbone=TimmBackbone(backbone="efficientnetv2_s_21k", pretrained=True), - neck=GlobalAveragePooling(dim=2), - head=LinearClsHead( - num_classes=num_classes, - in_channels=1280, - topk=(1, 5) if num_classes >= 5 else (1,), - loss=nn.CrossEntropyLoss(reduction="none"), - ), - ) - - def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: - """Load the previous OTX ckpt according to OTX2.0.""" - return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "multiclass", add_prefix) - - def _customize_inputs(self, inputs: MulticlassClsBatchDataEntity) -> dict[str, Any]: - if self.training: - mode = "loss" - elif self.explain_mode: - mode = "explain" - else: - mode = "predict" - - return { - "images": inputs.stacked_images, - "labels": torch.cat(inputs.labels, dim=0), - "imgs_info": inputs.imgs_info, - "mode": mode, - } - - def _customize_outputs( - self, - outputs: Any, # noqa: ANN401 - inputs: MulticlassClsBatchDataEntity, - ) -> MulticlassClsBatchPredEntity | OTXBatchLossEntity: - if self.training: - return OTXBatchLossEntity(loss=outputs) - - # To list, batch-wise - logits = outputs if isinstance(outputs, torch.Tensor) else outputs["logits"] - scores = torch.unbind(logits, 0) - preds = logits.argmax(-1, keepdim=True).unbind(0) - - return MulticlassClsBatchPredEntity( - batch_size=inputs.batch_size, - images=inputs.images, - imgs_info=inputs.imgs_info, - scores=scores, - labels=preds, - ) - - @property - def _exporter(self) -> OTXModelExporter: - """Creates OTXModelExporter object that can export the model.""" - return OTXNativeModelExporter( - task_level_export_parameters=self._export_parameters, - input_size=self.image_size, - mean=(123.675, 116.28, 103.53), - std=(58.395, 57.12, 57.375), - resize_mode="standard", - pad_value=0, - swap_rgb=False, - via_onnx=False, - onnx_export_configuration=None, - output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, - ) - - def forward_explain(self, inputs: MulticlassClsBatchDataEntity) -> MulticlassClsBatchPredEntity: - """Model forward explain function.""" - outputs = self.model(images=inputs.stacked_images, mode="explain") - - return MulticlassClsBatchPredEntity( - batch_size=len(outputs["preds"]), - images=inputs.images, - imgs_info=inputs.imgs_info, - labels=outputs["preds"], - scores=outputs["scores"], - saliency_map=outputs["saliency_map"], - feature_vector=outputs["feature_vector"], - ) - - def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]: - """Model forward function used for the model tracing during model exportation.""" - if self.explain_mode: - return self.model(images=image, mode="explain") - - return self.model(images=image, mode="tensor") - - -class EfficientNetV2ForMulticlassClsSemiSL(EfficientNetV2ForMulticlassCls): - """EfficientNetV2 model for multiclass classification with semi-supervised learning. - - This class extends the `EfficientNetV2ForMulticlassCls` class and adds support for semi-supervised learning. - It overrides the `_build_model` and `_customize_inputs` methods to incorporate the semi-supervised learning. - - Args: - EfficientNetV2ForMulticlassCls (class): The base class for EfficientNetV2 multiclass classification. - """ - - def _build_model(self, num_classes: int) -> nn.Module: - return SemiSLClassifier( - backbone=TimmBackbone(backbone="efficientnetv2_s_21k", pretrained=True), - neck=GlobalAveragePooling(dim=2), - head=OTXSemiSLLinearClsHead( - num_classes=num_classes, - in_channels=1280, - loss=nn.CrossEntropyLoss(reduction="none"), - ), - ) - - def _customize_inputs(self, inputs: MulticlassClsBatchDataEntity) -> dict[str, Any]: - """Customizes the input data for the model based on the current mode. - - Args: - inputs (MulticlassClsBatchDataEntity): The input batch of data. - - Returns: - dict[str, Any]: The customized input data. - """ - if self.training: - mode = "loss" - elif self.explain_mode: - mode = "explain" - else: - mode = "predict" - - if isinstance(inputs, dict): - # When used with an unlabeled dataset, it comes in as a dict. - images = {key: inputs[key].images for key in inputs} - labels = {key: torch.cat(inputs[key].labels, dim=0) for key in inputs} - imgs_info = {key: inputs[key].imgs_info for key in inputs} - return { - "images": images, - "labels": labels, - "imgs_info": imgs_info, - "mode": mode, - } - return { - "images": inputs.stacked_images, - "labels": torch.cat(inputs.labels, dim=0), - "imgs_info": inputs.imgs_info, - "mode": mode, - } - - def training_step(self, batch: MulticlassClsBatchDataEntity, batch_idx: int) -> Tensor: - """Performs a single training step on a batch of data. - - Args: - batch (MulticlassClsBatchDataEntity): The input batch of data. - batch_idx (int): The index of the current batch. - - Returns: - Tensor: The computed loss for the training step. - """ - loss = super().training_step(batch, batch_idx) - # Collect metrics related to Semi-SL Training. - self.log( - "train/unlabeled_coef", - self.model.head.unlabeled_coef, - on_step=True, - on_epoch=False, - prog_bar=True, - ) - self.log( - "train/num_pseudo_label", - self.model.head.num_pseudo_label, - on_step=True, - on_epoch=False, - prog_bar=True, - ) - return loss - - -class EfficientNetV2ForMultilabelCls(OTXMultilabelClsModel): - """EfficientNetV2 Model for multi-label classification task.""" - - def __init__( - self, - label_info: LabelInfoTypes, - optimizer: OptimizerCallable = DefaultOptimizerCallable, - scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, - metric: MetricCallable = MultiLabelClsMetricCallable, - torch_compile: bool = False, - ) -> None: - super().__init__( - label_info=label_info, - optimizer=optimizer, - scheduler=scheduler, - metric=metric, - torch_compile=torch_compile, - ) - - def _create_model(self) -> nn.Module: - # Get classification_layers for class-incr learning - sample_model_dict = self._build_model(num_classes=5).state_dict() - incremental_model_dict = self._build_model(num_classes=6).state_dict() - self.classification_layers = get_classification_layers( - sample_model_dict, - incremental_model_dict, - prefix="model.", - ) - - model = self._build_model(num_classes=self.num_classes) - model.init_weights() - return model - - def _build_model(self, num_classes: int) -> nn.Module: - return ImageClassifier( - backbone=TimmBackbone(backbone="efficientnetv2_s_21k", pretrained=True), - neck=GlobalAveragePooling(dim=2), - head=MultiLabelLinearClsHead( - num_classes=num_classes, - in_channels=1280, - normalized=True, - scale=7.0, - loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), - ), - ) - - def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: - """Load the previous OTX ckpt according to OTX2.0.""" - return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "multilabel", add_prefix) - - def _customize_inputs(self, inputs: MultilabelClsBatchDataEntity) -> dict[str, Any]: - if self.training: - mode = "loss" - elif self.explain_mode: - mode = "explain" - else: - mode = "predict" - - return { - "images": inputs.stacked_images, - "labels": torch.stack(inputs.labels), - "imgs_info": inputs.imgs_info, - "mode": mode, - } - - def _customize_outputs( - self, - outputs: Any, # noqa: ANN401 - inputs: MultilabelClsBatchDataEntity, - ) -> MultilabelClsBatchPredEntity | OTXBatchLossEntity: - if self.training: - return OTXBatchLossEntity(loss=outputs) - - # To list, batch-wise - logits = outputs if isinstance(outputs, torch.Tensor) else outputs["logits"] - scores = torch.unbind(logits, 0) - - return MultilabelClsBatchPredEntity( - batch_size=inputs.batch_size, - images=inputs.images, - imgs_info=inputs.imgs_info, - scores=scores, - labels=logits.argmax(-1, keepdim=True).unbind(0), - ) - - @property - def _exporter(self) -> OTXModelExporter: - """Creates OTXModelExporter object that can export the model.""" - return OTXNativeModelExporter( - task_level_export_parameters=self._export_parameters, - input_size=(1, 3, 224, 224), - mean=(123.675, 116.28, 103.53), - std=(58.395, 57.12, 57.375), - resize_mode="standard", - pad_value=0, - swap_rgb=False, - via_onnx=False, - onnx_export_configuration=None, - output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, - ) - - def forward_explain(self, inputs: MultilabelClsBatchDataEntity) -> MultilabelClsBatchPredEntity: - """Model forward explain function.""" - outputs = self.model(images=inputs.stacked_images, mode="explain") - - return MultilabelClsBatchPredEntity( - batch_size=len(outputs["preds"]), - images=inputs.images, - imgs_info=inputs.imgs_info, - labels=outputs["preds"], - scores=outputs["scores"], - saliency_map=outputs["saliency_map"], - feature_vector=outputs["feature_vector"], - ) - - def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]: - """Model forward function used for the model tracing during model exportation.""" - if self.explain_mode: - return self.model(images=image, mode="explain") - - return self.model(images=image, mode="tensor") - - -class EfficientNetV2ForHLabelCls(OTXHlabelClsModel): - """EfficientNetV2 Model for hierarchical label classification task.""" - - label_info: HLabelInfo - - def __init__( - self, - label_info: HLabelInfo, - optimizer: OptimizerCallable = DefaultOptimizerCallable, - scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, - metric: MetricCallable = HLabelClsMetricCallble, - torch_compile: bool = False, - ) -> None: - super().__init__( - label_info=label_info, - optimizer=optimizer, - scheduler=scheduler, - metric=metric, - torch_compile=torch_compile, - ) - - def _create_model(self) -> nn.Module: - # Get classification_layers for class-incr learning - sample_config = deepcopy(self.label_info.as_head_config_dict()) - sample_config["num_classes"] = 5 - sample_model_dict = self._build_model(head_config=sample_config).state_dict() - sample_config["num_classes"] = 6 - incremental_model_dict = self._build_model(head_config=sample_config).state_dict() - self.classification_layers = get_classification_layers( - sample_model_dict, - incremental_model_dict, - prefix="model.", - ) - - model = self._build_model(head_config=self.label_info.as_head_config_dict()) - model.init_weights() - return model - - def _build_model(self, head_config: dict) -> nn.Module: - return ImageClassifier( - backbone=TimmBackbone(backbone="efficientnetv2_s_21k", pretrained=True), - neck=nn.Identity(), - head=HierarchicalCBAMClsHead( - in_channels=1280, - multiclass_loss=nn.CrossEntropyLoss(), - multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), - **head_config, - ), - ) - - def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: - """Load the previous OTX ckpt according to OTX2.0.""" - return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "hlabel", add_prefix) - - def _customize_inputs(self, inputs: HlabelClsBatchDataEntity) -> dict[str, Any]: - if self.training: - mode = "loss" - elif self.explain_mode: - mode = "explain" - else: - mode = "predict" - - return { - "images": inputs.stacked_images, - "labels": torch.stack(inputs.labels), - "imgs_info": inputs.imgs_info, - "mode": mode, - } - - def _customize_outputs( - self, - outputs: Any, # noqa: ANN401 - inputs: HlabelClsBatchDataEntity, - ) -> HlabelClsBatchPredEntity | OTXBatchLossEntity: - if self.training: - return OTXBatchLossEntity(loss=outputs) - - # To list, batch-wise - if isinstance(outputs, dict): - scores = outputs["scores"] - labels = outputs["labels"] - else: - scores = outputs - labels = outputs.argmax(-1, keepdim=True) - - return HlabelClsBatchPredEntity( - batch_size=inputs.batch_size, - images=inputs.images, - imgs_info=inputs.imgs_info, - scores=scores, - labels=labels, - ) - - def _convert_pred_entity_to_compute_metric( - self, - preds: HlabelClsBatchPredEntity, - inputs: HlabelClsBatchDataEntity, - ) -> MetricInput: - hlabel_info: HLabelInfo = self.label_info # type: ignore[assignment] - - _labels = torch.stack(preds.labels) if isinstance(preds.labels, list) else preds.labels - _scores = torch.stack(preds.scores) if isinstance(preds.scores, list) else preds.scores - if hlabel_info.num_multilabel_classes > 0: - preds_multiclass = _labels[:, : hlabel_info.num_multiclass_heads] - preds_multilabel = _scores[:, hlabel_info.num_multiclass_heads :] - pred_result = torch.cat([preds_multiclass, preds_multilabel], dim=1) - else: - pred_result = _labels - return { - "preds": pred_result, - "target": torch.stack(inputs.labels), - } - - @property - def _exporter(self) -> OTXModelExporter: - """Creates OTXModelExporter object that can export the model.""" - return OTXNativeModelExporter( - task_level_export_parameters=self._export_parameters, - input_size=(1, 3, 224, 224), - mean=(123.675, 116.28, 103.53), - std=(58.395, 57.12, 57.375), - resize_mode="standard", - pad_value=0, - swap_rgb=False, - via_onnx=False, - onnx_export_configuration=None, - output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, - ) - - def forward_explain(self, inputs: HlabelClsBatchDataEntity) -> HlabelClsBatchPredEntity: - """Model forward explain function.""" - outputs = self.model(images=inputs.stacked_images, mode="explain") - - return HlabelClsBatchPredEntity( - batch_size=len(outputs["preds"]), - images=inputs.images, - imgs_info=inputs.imgs_info, - labels=outputs["preds"], - scores=outputs["scores"], - saliency_map=outputs["saliency_map"], - feature_vector=outputs["feature_vector"], - ) - - def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]: - """Model forward function used for the model tracing during model exportation.""" - if self.explain_mode: - return self.model(images=image, mode="explain") - - return self.model(images=image, mode="tensor") diff --git a/src/otx/algo/classification/mobilenet_v3.py b/src/otx/algo/classification/mobilenet_v3.py index d2696f59ce8..5697427f4b7 100644 --- a/src/otx/algo/classification/mobilenet_v3.py +++ b/src/otx/algo/classification/mobilenet_v3.py @@ -332,6 +332,7 @@ def _build_model(self, head_config: dict) -> nn.Module: multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), **head_config, ), + optimize_gap=False, ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: diff --git a/src/otx/algo/classification/timm_model.py b/src/otx/algo/classification/timm_model.py index 04eaf5ff396..f8d009fb8ca 100644 --- a/src/otx/algo/classification/timm_model.py +++ b/src/otx/algo/classification/timm_model.py @@ -13,7 +13,7 @@ from otx.algo.classification.backbones.timm import TimmBackbone, TimmModelType from otx.algo.classification.classifier import ImageClassifier, SemiSLClassifier from otx.algo.classification.heads import ( - HierarchicalLinearClsHead, + HierarchicalCBAMClsHead, LinearClsHead, MultiLabelLinearClsHead, OTXSemiSLLinearClsHead, @@ -264,13 +264,14 @@ def _build_model(self, head_config: dict) -> nn.Module: backbone = TimmBackbone(backbone=self.backbone, pretrained=self.pretrained) return ImageClassifier( backbone=backbone, - neck=GlobalAveragePooling(dim=2), - head=HierarchicalLinearClsHead( + neck=nn.Identity(), + head=HierarchicalCBAMClsHead( in_channels=backbone.num_features, multiclass_loss=nn.CrossEntropyLoss(), multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), **head_config, ), + optimize_gap=False, ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: diff --git a/tests/unit/algo/classification/test_timm_model.py b/tests/unit/algo/classification/test_timm_model.py index fbb4d6fbbc0..b20bcf7eba9 100644 --- a/tests/unit/algo/classification/test_timm_model.py +++ b/tests/unit/algo/classification/test_timm_model.py @@ -94,9 +94,9 @@ def test_predict_step(self, fxt_multi_label_cls_model, fxt_multilabel_cls_batch_ @pytest.fixture() -def fxt_h_label_cls_model(fxt_hlabel_data): +def fxt_h_label_cls_model(fxt_hlabel_cifar): return TimmModelForHLabelCls( - label_info=fxt_hlabel_data, + label_info=fxt_hlabel_cifar, backbone="efficientnetv2_s_21k", ) From c5e5de564ffbd8d5c44d95ecc0e8a3a9ef91ebc3 Mon Sep 17 00:00:00 2001 From: sooahleex Date: Sat, 10 Aug 2024 16:55:00 +0900 Subject: [PATCH 14/16] Add unit tests for CBAM --- .../heads/test_hlabel_cls_head.py | 64 +++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/tests/unit/algo/classification/heads/test_hlabel_cls_head.py b/tests/unit/algo/classification/heads/test_hlabel_cls_head.py index 9edc1f53d30..11e7191dc49 100644 --- a/tests/unit/algo/classification/heads/test_hlabel_cls_head.py +++ b/tests/unit/algo/classification/heads/test_hlabel_cls_head.py @@ -13,6 +13,7 @@ HierarchicalLinearClsHead, HierarchicalNonLinearClsHead, ) +from otx.algo.classification.heads.hlabel_cls_head import CBAM, ChannelAttention, SpatialAttention from otx.algo.classification.losses import AsymmetricAngularLossWithIgnore from otx.core.data.entity.base import ImageInfo from torch import nn @@ -105,3 +106,66 @@ def test_predict( assert result["scores"].shape == (2, 3) assert "labels" in result assert result["labels"].shape == (2, 3) + + +class TestChannelAttention: + @pytest.fixture() + def fxt_channel_attention(self) -> ChannelAttention: + return ChannelAttention(in_channels=64, reduction=16) + + def test_forward(self, fxt_channel_attention) -> None: + input_tensor = torch.rand((8, 64, 32, 32)) + result = fxt_channel_attention(input_tensor) + assert torch.all(result >= 0) + assert torch.all(result <= 1) + + +class TestSpatialAttention: + @pytest.fixture() + def fxt_spatial_attention(self) -> SpatialAttention: + return SpatialAttention(kernel_size=7) + + def test_forward(self, fxt_spatial_attention) -> None: + input_tensor = torch.rand((8, 64, 32, 32)) + result = fxt_spatial_attention(input_tensor) + assert torch.all(result >= 0) + assert torch.all(result <= 1) + + +class TestCBAM: + @pytest.fixture() + def fxt_cbam(self) -> CBAM: + return CBAM(in_channels=64, reduction=16, kernel_size=7) + + def test_forward(self, fxt_cbam) -> None: + input_tensor = torch.rand((8, 64, 32, 32)) + result = fxt_cbam(input_tensor) + assert torch.all(result >= 0) + assert torch.all(result <= 1) + + +class TestHierarchicalCBAMClsHead: + @pytest.fixture() + def fxt_hierarchical_cbam_cls_head(self) -> HierarchicalCBAMClsHead: + head_idx_to_logits_range = {"0": (0, 5), "1": (5, 10), "2": (10, 12)} + return HierarchicalCBAMClsHead( + num_multiclass_heads=3, + num_multilabel_classes=0, + head_idx_to_logits_range=head_idx_to_logits_range, + num_single_label_classes=12, + empty_multiclass_head_indices=[], + in_channels=64, + num_classes=12, + multiclass_loss=CrossEntropyLoss(), + multilabel_loss=None, + ) + + def test_forward(self, fxt_hierarchical_cbam_cls_head) -> None: + input_tensor = torch.rand((8, 64, 7, 7)) + result = fxt_hierarchical_cbam_cls_head(input_tensor) + assert result.shape == (8, 12) + + def test_pre_logits(self, fxt_hierarchical_cbam_cls_head) -> None: + input_tensor = torch.rand((8, 64, 7, 7)) + pre_logits = fxt_hierarchical_cbam_cls_head.pre_logits(input_tensor) + assert pre_logits.shape == (8, 64 * 7 * 7) From 33cc51723ab3feddbdcf65206dd055cfdf4dd339 Mon Sep 17 00:00:00 2001 From: sooahleex Date: Sat, 10 Aug 2024 17:33:15 +0900 Subject: [PATCH 15/16] Update comments --- src/otx/algo/classification/heads/hlabel_cls_head.py | 2 ++ src/otx/algo/classification/vit.py | 1 + 2 files changed, 3 insertions(+) diff --git a/src/otx/algo/classification/heads/hlabel_cls_head.py b/src/otx/algo/classification/heads/hlabel_cls_head.py index 07c0c97df05..f92104df057 100644 --- a/src/otx/algo/classification/heads/hlabel_cls_head.py +++ b/src/otx/algo/classification/heads/hlabel_cls_head.py @@ -418,6 +418,8 @@ class HierarchicalCBAMClsHead(HierarchicalClsHead): multilabel_loss (dict | None): Config of multi-label loss. thr (float | None): Predictions with scores under the thresholds are considered as negative. Defaults to 0.5. + init_cfg (dict | None, optional): Initialize configuration key-values, Defaults to None. + step_size (int, optional): Step size value for HierarchicalCBAMClsHead, Defaults to 7. """ def __init__( diff --git a/src/otx/algo/classification/vit.py b/src/otx/algo/classification/vit.py index a6638249f24..993748aa3d0 100644 --- a/src/otx/algo/classification/vit.py +++ b/src/otx/algo/classification/vit.py @@ -501,5 +501,6 @@ def _build_model(self, head_config: dict) -> nn.Module: step_size=1, **head_config, ), + optimize_gap=False, init_cfg=init_cfg, ) From e5249090e2a60624a32bca558ce177f41c4ce05e Mon Sep 17 00:00:00 2001 From: sooahleex Date: Sun, 11 Aug 2024 12:47:20 +0900 Subject: [PATCH 16/16] Update comments for arguments of CBAM head --- src/otx/algo/classification/heads/hlabel_cls_head.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/otx/algo/classification/heads/hlabel_cls_head.py b/src/otx/algo/classification/heads/hlabel_cls_head.py index f92104df057..1b5767c4ace 100644 --- a/src/otx/algo/classification/heads/hlabel_cls_head.py +++ b/src/otx/algo/classification/heads/hlabel_cls_head.py @@ -408,15 +408,15 @@ class HierarchicalCBAMClsHead(HierarchicalClsHead): Args: num_multiclass_heads (int): Number of multi-class heads. num_multilabel_classes (int): Number of multi-label classes. - head_idx_to_logits_range: the logit range of each heads - num_single_label_classes: the number of single label classes - empty_multiclass_head_indices: the index of head that doesn't include any label + head_idx_to_logits_range (dict[str, tuple[int, int]]): the logit range of each heads + num_single_label_classes (int): the number of single label classes + empty_multiclass_head_indices (list[int]): the index of head that doesn't include any label due to the label removing in_channels (int): Number of channels in the input feature map. num_classes (int): Number of total classes. - multiclass_loss (dict | None): Config of multi-class loss. - multilabel_loss (dict | None): Config of multi-label loss. - thr (float | None): Predictions with scores under the thresholds are considered + multiclass_loss (nn.Module): Config of multi-class loss. + multilabel_loss (nn.Module | None, optional): Config of multi-label loss. + thr (float, optional): Predictions with scores under the thresholds are considered as negative. Defaults to 0.5. init_cfg (dict | None, optional): Initialize configuration key-values, Defaults to None. step_size (int, optional): Step size value for HierarchicalCBAMClsHead, Defaults to 7.