diff --git a/CHANGELOG.md b/CHANGELOG.md index 63b53d13b83..11ed67e7ae3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,8 @@ All notable changes to this project will be documented in this file. (https://github.com/openvinotoolkit/training_extensions/pull/3762) - Add RTMPose for Keypoint Detection Task (https://github.com/openvinotoolkit/training_extensions/pull/3781) +- Update head and h-label format for hierarchical label classification + (https://github.com/openvinotoolkit/training_extensions/pull/3810) ### Enhancements diff --git a/src/otx/algo/classification/classifier/base_classifier.py b/src/otx/algo/classification/classifier/base_classifier.py index 174dee052a0..3c5126824b2 100644 --- a/src/otx/algo/classification/classifier/base_classifier.py +++ b/src/otx/algo/classification/classifier/base_classifier.py @@ -61,6 +61,7 @@ def __init__( neck: nn.Module | None, head: nn.Module, pretrained: str | None = None, + optimize_gap: bool = True, mean: list[float] | None = None, std: list[float] | None = None, to_rgb: bool = False, @@ -81,7 +82,7 @@ def __init__( self.explainer = ReciproCAM( self._head_forward_fn, num_classes=head.num_classes, - optimize_gap=True, + optimize_gap=optimize_gap, ) def forward( diff --git a/src/otx/algo/classification/efficientnet.py b/src/otx/algo/classification/efficientnet.py index ce036ae0aa4..46ac9597f20 100644 --- a/src/otx/algo/classification/efficientnet.py +++ b/src/otx/algo/classification/efficientnet.py @@ -15,7 +15,7 @@ from otx.algo.classification.classifier.base_classifier import ImageClassifier from otx.algo.classification.classifier.semi_sl_classifier import SemiSLClassifier from otx.algo.classification.heads import ( - HierarchicalLinearClsHead, + HierarchicalCBAMClsHead, LinearClsHead, MultiLabelLinearClsHead, OTXSemiSLLinearClsHead, @@ -265,13 +265,14 @@ def _build_model(self, head_config: dict) -> nn.Module: backbone = OTXEfficientNet(version=self.version, pretrained=self.pretrained) return ImageClassifier( backbone=backbone, - neck=GlobalAveragePooling(dim=2), - head=HierarchicalLinearClsHead( + neck=nn.Identity(), + head=HierarchicalCBAMClsHead( in_channels=backbone.num_features, multiclass_loss=nn.CrossEntropyLoss(), multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), **head_config, ), + optimize_gap=False, ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: diff --git a/src/otx/algo/classification/heads/__init__.py b/src/otx/algo/classification/heads/__init__.py index aea5e1f0f4c..a920d6782bb 100644 --- a/src/otx/algo/classification/heads/__init__.py +++ b/src/otx/algo/classification/heads/__init__.py @@ -3,7 +3,7 @@ # """Head modules for OTX custom model.""" -from .hlabel_cls_head import HierarchicalLinearClsHead, HierarchicalNonLinearClsHead +from .hlabel_cls_head import HierarchicalCBAMClsHead, HierarchicalLinearClsHead, HierarchicalNonLinearClsHead from .linear_head import LinearClsHead from .multilabel_cls_head import MultiLabelLinearClsHead, MultiLabelNonLinearClsHead from .semi_sl_head import OTXSemiSLLinearClsHead, OTXSemiSLVisionTransformerClsHead @@ -15,6 +15,7 @@ "MultiLabelNonLinearClsHead", "HierarchicalLinearClsHead", "HierarchicalNonLinearClsHead", + "HierarchicalCBAMClsHead", "VisionTransformerClsHead", "OTXSemiSLLinearClsHead", "OTXSemiSLVisionTransformerClsHead", diff --git a/src/otx/algo/classification/heads/hlabel_cls_head.py b/src/otx/algo/classification/heads/hlabel_cls_head.py index b976b83642f..1b5767c4ace 100644 --- a/src/otx/algo/classification/heads/hlabel_cls_head.py +++ b/src/otx/algo/classification/heads/hlabel_cls_head.py @@ -353,3 +353,136 @@ def forward(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor: """The forward process.""" pre_logits = self.pre_logits(feats) return self.classifier(pre_logits) + + +class ChannelAttention(nn.Module): + """Channel attention module that uses average and max pooling to enhance important channels.""" + + def __init__(self, in_channels: int, reduction: int = 16): + """Initializes the ChannelAttention module.""" + super().__init__() + self.fc1 = nn.Conv2d(in_channels, in_channels // reduction, kernel_size=1, bias=False) + self.fc2 = nn.Conv2d(in_channels // reduction, in_channels, kernel_size=1, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Applies channel attention to the input tensor.""" + avg_out = self.fc2(torch.relu(self.fc1(torch.mean(x, dim=2, keepdim=True).mean(dim=3, keepdim=True)))) + max_out = self.fc2(torch.relu(self.fc1(torch.max(x, dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]))) + return torch.sigmoid(avg_out + max_out) + + +class SpatialAttention(nn.Module): + """Spatial attention module that uses average and max pooling to enhance important spatial locations.""" + + def __init__(self, kernel_size: int = 7): + """Initializes the SpatialAttention module.""" + super().__init__() + self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Applies spatial attention to the input tensor.""" + avg_out = torch.mean(x, dim=1, keepdim=True) + max_out = torch.max(x, dim=1, keepdim=True)[0] + x = torch.cat([avg_out, max_out], dim=1) + return torch.sigmoid(self.conv(x)) + + +class CBAM(nn.Module): + """CBAM module that applies channel and spatial attention sequentially.""" + + def __init__(self, in_channels: int, reduction: int = 16, kernel_size: int = 7): + """Initializes the CBAM module with channel and spatial attention.""" + super().__init__() + self.channel_attention = ChannelAttention(in_channels, reduction) + self.spatial_attention = SpatialAttention(kernel_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Applies channel and spatial attention to the input tensor.""" + x = x * self.channel_attention(x) + return x * self.spatial_attention(x) + + +class HierarchicalCBAMClsHead(HierarchicalClsHead): + """Custom classification CBAM head for hierarchical classification task. + + Args: + num_multiclass_heads (int): Number of multi-class heads. + num_multilabel_classes (int): Number of multi-label classes. + head_idx_to_logits_range (dict[str, tuple[int, int]]): the logit range of each heads + num_single_label_classes (int): the number of single label classes + empty_multiclass_head_indices (list[int]): the index of head that doesn't include any label + due to the label removing + in_channels (int): Number of channels in the input feature map. + num_classes (int): Number of total classes. + multiclass_loss (nn.Module): Config of multi-class loss. + multilabel_loss (nn.Module | None, optional): Config of multi-label loss. + thr (float, optional): Predictions with scores under the thresholds are considered + as negative. Defaults to 0.5. + init_cfg (dict | None, optional): Initialize configuration key-values, Defaults to None. + step_size (int, optional): Step size value for HierarchicalCBAMClsHead, Defaults to 7. + """ + + def __init__( + self, + num_multiclass_heads: int, + num_multilabel_classes: int, + head_idx_to_logits_range: dict[str, tuple[int, int]], + num_single_label_classes: int, + empty_multiclass_head_indices: list[int], + in_channels: int, + num_classes: int, + multiclass_loss: nn.Module, + multilabel_loss: nn.Module | None = None, + thr: float = 0.5, + init_cfg: dict | None = None, + step_size: int = 7, + **kwargs, + ): + super().__init__( + num_multiclass_heads=num_multiclass_heads, + num_multilabel_classes=num_multilabel_classes, + head_idx_to_logits_range=head_idx_to_logits_range, + num_single_label_classes=num_single_label_classes, + empty_multiclass_head_indices=empty_multiclass_head_indices, + in_channels=in_channels, + num_classes=num_classes, + multiclass_loss=multiclass_loss, + multilabel_loss=multilabel_loss, + thr=thr, + init_cfg=init_cfg, + **kwargs, + ) + self.step_size = step_size + self.fc_superclass = nn.Linear(in_channels * step_size * step_size, num_multiclass_heads) + self.attention_fc = nn.Linear(num_multiclass_heads, in_channels * step_size * step_size) + self.cbam = CBAM(in_channels) + self.fc_subclass = nn.Linear(in_channels * step_size * step_size, num_single_label_classes) + + self._init_layers() + + def pre_logits(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor: + """The process before the final classification head.""" + if isinstance(feats, Sequence): + feats = feats[-1] + return feats.view(feats.size(0), self.in_channels * self.step_size * self.step_size) + + def _init_layers(self) -> None: + """Iniitialize weights of classification head.""" + normal_init(self.fc_superclass, mean=0, std=0.01, bias=0) + normal_init(self.fc_subclass, mean=0, std=0.01, bias=0) + + def forward(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + out_superclass = self.fc_superclass(pre_logits) + + attention_weights = torch.sigmoid(self.attention_fc(out_superclass)) + attended_features = pre_logits * attention_weights + + attended_features = attended_features.view(pre_logits.size(0), self.in_channels, self.step_size, self.step_size) + attended_features = self.cbam(attended_features) + attended_features = attended_features.view( + pre_logits.size(0), + self.in_channels * self.step_size * self.step_size, + ) + return self.fc_subclass(attended_features) diff --git a/src/otx/algo/classification/mobilenet_v3.py b/src/otx/algo/classification/mobilenet_v3.py index 756b52721ae..5697427f4b7 100644 --- a/src/otx/algo/classification/mobilenet_v3.py +++ b/src/otx/algo/classification/mobilenet_v3.py @@ -15,7 +15,7 @@ from otx.algo.classification.backbones import OTXMobileNetV3 from otx.algo.classification.classifier import ImageClassifier, SemiSLClassifier from otx.algo.classification.heads import ( - HierarchicalNonLinearClsHead, + HierarchicalCBAMClsHead, LinearClsHead, MultiLabelNonLinearClsHead, OTXSemiSLLinearClsHead, @@ -325,13 +325,14 @@ def _build_model(self, head_config: dict) -> nn.Module: return ImageClassifier( backbone=OTXMobileNetV3(mode=self.mode), - neck=GlobalAveragePooling(dim=2), - head=HierarchicalNonLinearClsHead( + neck=nn.Identity(), + head=HierarchicalCBAMClsHead( in_channels=960, multiclass_loss=nn.CrossEntropyLoss(), multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), **head_config, ), + optimize_gap=False, ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: diff --git a/src/otx/algo/classification/timm_model.py b/src/otx/algo/classification/timm_model.py index 04eaf5ff396..f8d009fb8ca 100644 --- a/src/otx/algo/classification/timm_model.py +++ b/src/otx/algo/classification/timm_model.py @@ -13,7 +13,7 @@ from otx.algo.classification.backbones.timm import TimmBackbone, TimmModelType from otx.algo.classification.classifier import ImageClassifier, SemiSLClassifier from otx.algo.classification.heads import ( - HierarchicalLinearClsHead, + HierarchicalCBAMClsHead, LinearClsHead, MultiLabelLinearClsHead, OTXSemiSLLinearClsHead, @@ -264,13 +264,14 @@ def _build_model(self, head_config: dict) -> nn.Module: backbone = TimmBackbone(backbone=self.backbone, pretrained=self.pretrained) return ImageClassifier( backbone=backbone, - neck=GlobalAveragePooling(dim=2), - head=HierarchicalLinearClsHead( + neck=nn.Identity(), + head=HierarchicalCBAMClsHead( in_channels=backbone.num_features, multiclass_loss=nn.CrossEntropyLoss(), multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), **head_config, ), + optimize_gap=False, ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: diff --git a/src/otx/algo/classification/torchvision_model.py b/src/otx/algo/classification/torchvision_model.py index 7a81b536965..b9e3c2a973e 100644 --- a/src/otx/algo/classification/torchvision_model.py +++ b/src/otx/algo/classification/torchvision_model.py @@ -11,7 +11,7 @@ from torch import Tensor, nn from torchvision.models import get_model, get_model_weights -from otx.algo.classification.heads import HierarchicalLinearClsHead, MultiLabelLinearClsHead, OTXSemiSLLinearClsHead +from otx.algo.classification.heads import HierarchicalCBAMClsHead, MultiLabelLinearClsHead, OTXSemiSLLinearClsHead from otx.algo.classification.losses import AsymmetricAngularLossWithIgnore from otx.algo.explain.explain_algo import ReciproCAM, feature_vector_fn from otx.core.data.entity.base import OTXBatchLossEntity @@ -209,12 +209,13 @@ def _get_head(self, net: nn.Module) -> nn.Module: loss=self.loss_module, ) if self.task == OTXTaskType.H_LABEL_CLS: - self.neck = nn.Sequential(*layers) if layers else None - return HierarchicalLinearClsHead( + self.neck = nn.Sequential(*layers, nn.Identity()) if layers else None + return HierarchicalCBAMClsHead( in_channels=feature_channel, multiclass_loss=nn.CrossEntropyLoss(), multilabel_loss=self.loss_module, **self.head_config, + step_size=1, ) msg = f"Task type {self.task} is not supported." diff --git a/src/otx/algo/classification/vit.py b/src/otx/algo/classification/vit.py index f2ccd09b8d9..993748aa3d0 100644 --- a/src/otx/algo/classification/vit.py +++ b/src/otx/algo/classification/vit.py @@ -18,7 +18,7 @@ from otx.algo.classification.backbones.vision_transformer import VIT_ARCH_TYPE, VisionTransformer from otx.algo.classification.classifier import ImageClassifier, SemiSLClassifier from otx.algo.classification.heads import ( - HierarchicalLinearClsHead, + HierarchicalCBAMClsHead, MultiLabelLinearClsHead, OTXSemiSLVisionTransformerClsHead, VisionTransformerClsHead, @@ -494,11 +494,13 @@ def _build_model(self, head_config: dict) -> nn.Module: return ImageClassifier( backbone=vit_backbone, neck=None, - head=HierarchicalLinearClsHead( + head=HierarchicalCBAMClsHead( in_channels=vit_backbone.embed_dim, multiclass_loss=nn.CrossEntropyLoss(), multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), + step_size=1, **head_config, ), + optimize_gap=False, init_cfg=init_cfg, ) diff --git a/src/otx/core/data/dataset/classification.py b/src/otx/core/data/dataset/classification.py index a26ffc2799e..57170da967b 100644 --- a/src/otx/core/data/dataset/classification.py +++ b/src/otx/core/data/dataset/classification.py @@ -145,7 +145,7 @@ def _get_label_group_idx(label_name: str) -> int: def _find_ancestor_recursively(label_name: str, ancestors: list) -> list[str]: _, dm_label_category = self.dm_categories.find(label_name) - parent_name = dm_label_category.parent + parent_name = dm_label_category.parent if dm_label_category else "" if parent_name != "": ancestors.append(parent_name) @@ -200,21 +200,22 @@ def _convert_label_to_hlabel_format(self, label_anns: list[Label], ignored_label """Convert format of the label to the h-label. It converts the label format to h-label format. + Total length of result is sum of number of hierarchy and number of multilabel classes. i.e. Let's assume that we used the same dataset with example of the definition of HLabelData - and the original labels are ["Rigid", "Panda", "Lion"]. + and the original labels are ["Rigid", "Triangle", "Lion"]. - Then, h-label format will be [1, -1, 0, 1, 1]. + Then, h-label format will be [0, 1, 1, 0]. The first N-th indices represent the label index of multiclass heads (N=num_multiclass_heads), others represent the multilabel labels. - [Multiclass Heads: [1, -1]] - 0-th index = 1 -> ["Non-Rigid"(X), "Rigid"(O)] <- First multiclass head - 1-st index = -1 -> ["Rectangle"(X), "Triangle"(X)] <- Second multiclass head + [Multiclass Heads] + 0-th index = 0 -> ["Rigid"(O), "Non-Rigid"(X)] <- First multiclass head + 1-st index = 1 -> ["Rectangle"(O), "Triangle"(X), "Circle"(X)] <- Second multiclass head - [Multilabel Head: [0, 1, 1]] - 2, 3, 4 indices = [0, 1, 1] -> ["Circle"(X), "Lion"(O), "Panda"(O)] + [Multilabel Head] + 2, 3 indices = [1, 0] -> ["Lion"(O), "Panda"(X)] """ if not isinstance(self.label_info, HLabelInfo): msg = f"The type of label_info should be HLabelInfo, got {type(self.label_info)}." @@ -229,10 +230,16 @@ def _convert_label_to_hlabel_format(self, label_anns: list[Label], ignored_label for ann in label_anns: ann_name = self.dm_categories.items[ann.label].name + ann_parent = self.dm_categories.items[ann.label].parent group_idx, in_group_idx = self.label_info.class_to_group_idx[ann_name] + (parent_group_idx, parent_in_group_idx) = ( + self.label_info.class_to_group_idx[ann_parent] if ann_parent else (None, None) + ) if group_idx < num_multiclass_heads: class_indices[group_idx] = in_group_idx + if parent_group_idx is not None and parent_in_group_idx is not None: + class_indices[parent_group_idx] = parent_in_group_idx elif not ignored_labels or ann.label not in ignored_labels: class_indices[num_multiclass_heads + in_group_idx] = 1 else: diff --git a/tests/unit/algo/classification/conftest.py b/tests/unit/algo/classification/conftest.py index 825ce7cabb1..945c3d0bc4c 100644 --- a/tests/unit/algo/classification/conftest.py +++ b/tests/unit/algo/classification/conftest.py @@ -132,6 +132,81 @@ def fxt_hlabel_multilabel_info() -> HLabelInfo: ) +@pytest.fixture() +def fxt_hlabel_cifar() -> HLabelInfo: + return HLabelInfo( + label_names=[ + "beaver", + "dolphin", + "otter", + "seal", + "whale", + "aquarium_fish", + "flatfish", + "ray", + "shark", + "trout", + "aquatic_mammals", + "fish", + ], + label_groups=[ + ["beaver", "dolphin", "otter", "seal", "whale"], + ["aquarium_fish", "flatfish", "ray", "shark", "trout"], + ["aquatic_mammals", "fish"], + ], + num_multiclass_heads=3, + num_multilabel_classes=0, + head_idx_to_logits_range={"0": (0, 5), "1": (5, 10), "2": (10, 12)}, + num_single_label_classes=12, + empty_multiclass_head_indices=[], + class_to_group_idx={ + "beaver": (0, 0), + "dolphin": (0, 1), + "otter": (0, 2), + "seal": (0, 3), + "whale": (0, 4), + "aquarium_fish": (1, 0), + "flatfish": (1, 1), + "ray": (1, 2), + "shark": (1, 3), + "trout": (1, 4), + "aquatic_mammals": (2, 0), + "fish": (2, 1), + }, + all_groups=[ + ["beaver", "dolphin", "otter", "seal", "whale"], + ["aquarium_fish", "flatfish", "ray", "shark", "trout"], + ["aquatic_mammals", "fish"], + ], + label_to_idx={ + "aquarium_fish": 0, + "beaver": 1, + "dolphin": 2, + "flatfish": 3, + "otter": 4, + "ray": 5, + "seal": 6, + "shark": 7, + "trout": 8, + "whale": 9, + "aquatic_mammals": 10, + "fish": 11, + }, + label_tree_edges=[ + ["aquarium_fish", "fish"], + ["beaver", "aquatic_mammals"], + ["dolphin", "aquatic_mammals"], + ["otter", "aquatic_mammals"], + ["seal", "aquatic_mammals"], + ["whale", "aquatic_mammals"], + ["flatfish", "aquarium_fish"], + ["ray", "aquarium_fish"], + ["shark", "aquarium_fish"], + ["trout", "aquarium_fish"], + ], + ) + + @pytest.fixture() def fxt_multiclass_cls_batch_data_entity() -> MulticlassClsBatchDataEntity: batch_size = 2 diff --git a/tests/unit/algo/classification/heads/test_hlabel_cls_head.py b/tests/unit/algo/classification/heads/test_hlabel_cls_head.py index 0b7880b9d96..11e7191dc49 100644 --- a/tests/unit/algo/classification/heads/test_hlabel_cls_head.py +++ b/tests/unit/algo/classification/heads/test_hlabel_cls_head.py @@ -8,7 +8,12 @@ import pytest import torch -from otx.algo.classification.heads import HierarchicalLinearClsHead, HierarchicalNonLinearClsHead +from otx.algo.classification.heads import ( + HierarchicalCBAMClsHead, + HierarchicalLinearClsHead, + HierarchicalNonLinearClsHead, +) +from otx.algo.classification.heads.hlabel_cls_head import CBAM, ChannelAttention, SpatialAttention from otx.algo.classification.losses import AsymmetricAngularLossWithIgnore from otx.core.data.entity.base import ImageInfo from torch import nn @@ -49,9 +54,9 @@ def fxt_data_sample_with_ignored_labels() -> dict: class TestHierarchicalLinearClsHead: @pytest.fixture() - def fxt_head_attrs(self, fxt_hlabel_multilabel_info) -> dict[str, Any]: + def fxt_head_attrs(self, fxt_hlabel_cifar) -> dict[str, Any]: return { - **fxt_hlabel_multilabel_info.as_head_config_dict(), + **fxt_hlabel_cifar.as_head_config_dict(), "in_channels": 24, "multiclass_loss": CrossEntropyLoss(), "multilabel_loss": AsymmetricAngularLossWithIgnore(), @@ -65,7 +70,12 @@ def fxt_hlabel_linear_head(self, fxt_head_attrs) -> nn.Module: def fxt_hlabel_non_linear_head(self, fxt_head_attrs) -> nn.Module: return HierarchicalNonLinearClsHead(**fxt_head_attrs) - @pytest.fixture(params=["fxt_hlabel_linear_head", "fxt_hlabel_non_linear_head"]) + @pytest.fixture() + def fxt_hlabel_cbam_head(self, fxt_head_attrs) -> nn.Module: + fxt_head_attrs["step_size"] = 1 + return HierarchicalCBAMClsHead(**fxt_head_attrs) + + @pytest.fixture(params=["fxt_hlabel_linear_head", "fxt_hlabel_non_linear_head", "fxt_hlabel_cbam_head"]) def fxt_hlabel_head(self, request) -> nn.Module: return request.getfixturevalue(request.param) @@ -93,6 +103,69 @@ def test_predict( result = fxt_hlabel_head.predict(dummy_input, **fxt_data_sample) assert isinstance(result, dict) assert "scores" in result - assert result["scores"].shape == (2, 6) + assert result["scores"].shape == (2, 3) assert "labels" in result - assert result["labels"].shape == (2, 6) + assert result["labels"].shape == (2, 3) + + +class TestChannelAttention: + @pytest.fixture() + def fxt_channel_attention(self) -> ChannelAttention: + return ChannelAttention(in_channels=64, reduction=16) + + def test_forward(self, fxt_channel_attention) -> None: + input_tensor = torch.rand((8, 64, 32, 32)) + result = fxt_channel_attention(input_tensor) + assert torch.all(result >= 0) + assert torch.all(result <= 1) + + +class TestSpatialAttention: + @pytest.fixture() + def fxt_spatial_attention(self) -> SpatialAttention: + return SpatialAttention(kernel_size=7) + + def test_forward(self, fxt_spatial_attention) -> None: + input_tensor = torch.rand((8, 64, 32, 32)) + result = fxt_spatial_attention(input_tensor) + assert torch.all(result >= 0) + assert torch.all(result <= 1) + + +class TestCBAM: + @pytest.fixture() + def fxt_cbam(self) -> CBAM: + return CBAM(in_channels=64, reduction=16, kernel_size=7) + + def test_forward(self, fxt_cbam) -> None: + input_tensor = torch.rand((8, 64, 32, 32)) + result = fxt_cbam(input_tensor) + assert torch.all(result >= 0) + assert torch.all(result <= 1) + + +class TestHierarchicalCBAMClsHead: + @pytest.fixture() + def fxt_hierarchical_cbam_cls_head(self) -> HierarchicalCBAMClsHead: + head_idx_to_logits_range = {"0": (0, 5), "1": (5, 10), "2": (10, 12)} + return HierarchicalCBAMClsHead( + num_multiclass_heads=3, + num_multilabel_classes=0, + head_idx_to_logits_range=head_idx_to_logits_range, + num_single_label_classes=12, + empty_multiclass_head_indices=[], + in_channels=64, + num_classes=12, + multiclass_loss=CrossEntropyLoss(), + multilabel_loss=None, + ) + + def test_forward(self, fxt_hierarchical_cbam_cls_head) -> None: + input_tensor = torch.rand((8, 64, 7, 7)) + result = fxt_hierarchical_cbam_cls_head(input_tensor) + assert result.shape == (8, 12) + + def test_pre_logits(self, fxt_hierarchical_cbam_cls_head) -> None: + input_tensor = torch.rand((8, 64, 7, 7)) + pre_logits = fxt_hierarchical_cbam_cls_head.pre_logits(input_tensor) + assert pre_logits.shape == (8, 64 * 7 * 7) diff --git a/tests/unit/algo/classification/test_deit_tiny.py b/tests/unit/algo/classification/test_deit_tiny.py index 420aa50ea35..908c91bb60e 100644 --- a/tests/unit/algo/classification/test_deit_tiny.py +++ b/tests/unit/algo/classification/test_deit_tiny.py @@ -18,7 +18,7 @@ class TestDeitTiny: params=[ (VisionTransformerForMulticlassCls, "fxt_multiclass_cls_batch_data_entity", "fxt_multiclass_labelinfo"), (VisionTransformerForMultilabelCls, "fxt_multilabel_cls_batch_data_entity", "fxt_multilabel_labelinfo"), - (VisionTransformerForHLabelCls, "fxt_hlabel_cls_batch_data_entity", "fxt_hlabel_data"), + (VisionTransformerForHLabelCls, "fxt_hlabel_cls_batch_data_entity", "fxt_hlabel_cifar"), ], ids=["multiclass", "multilabel", "hlabel"], ) diff --git a/tests/unit/algo/classification/test_efficientnet.py b/tests/unit/algo/classification/test_efficientnet.py index 49d16527f7a..b2f2edc688a 100644 --- a/tests/unit/algo/classification/test_efficientnet.py +++ b/tests/unit/algo/classification/test_efficientnet.py @@ -94,10 +94,10 @@ def test_predict_step(self, fxt_multi_label_cls_model, fxt_multilabel_cls_batch_ @pytest.fixture() -def fxt_h_label_cls_model(fxt_hlabel_data): +def fxt_h_label_cls_model(fxt_hlabel_cifar): return EfficientNetForHLabelCls( version="b0", - label_info=fxt_hlabel_data, + label_info=fxt_hlabel_cifar, ) diff --git a/tests/unit/algo/classification/test_mobilenet_v3.py b/tests/unit/algo/classification/test_mobilenet_v3.py index 60981098e1c..62ebcb6ed03 100644 --- a/tests/unit/algo/classification/test_mobilenet_v3.py +++ b/tests/unit/algo/classification/test_mobilenet_v3.py @@ -94,10 +94,10 @@ def test_predict_step(self, fxt_multi_label_cls_model, fxt_multilabel_cls_batch_ @pytest.fixture() -def fxt_h_label_cls_model(fxt_hlabel_data): +def fxt_h_label_cls_model(fxt_hlabel_cifar): return MobileNetV3ForHLabelCls( mode="large", - label_info=fxt_hlabel_data, + label_info=fxt_hlabel_cifar, ) diff --git a/tests/unit/algo/classification/test_timm_model.py b/tests/unit/algo/classification/test_timm_model.py index fbb4d6fbbc0..b20bcf7eba9 100644 --- a/tests/unit/algo/classification/test_timm_model.py +++ b/tests/unit/algo/classification/test_timm_model.py @@ -94,9 +94,9 @@ def test_predict_step(self, fxt_multi_label_cls_model, fxt_multilabel_cls_batch_ @pytest.fixture() -def fxt_h_label_cls_model(fxt_hlabel_data): +def fxt_h_label_cls_model(fxt_hlabel_cifar): return TimmModelForHLabelCls( - label_info=fxt_hlabel_data, + label_info=fxt_hlabel_cifar, backbone="efficientnetv2_s_21k", )