diff --git a/src/otx/algo/anomaly/padim.py b/src/otx/algo/anomaly/padim.py
index 201b0230a02..ab9a6ddb1a3 100644
--- a/src/otx/algo/anomaly/padim.py
+++ b/src/otx/algo/anomaly/padim.py
@@ -9,21 +9,21 @@
 
 from typing import TYPE_CHECKING, Literal
 
+from anomalib.callbacks.normalization.min_max_normalization import _MinMaxNormalizationCallback
+from anomalib.callbacks.post_processor import _PostProcessorCallback
 from anomalib.models.image import Padim as AnomalibPadim
 
 from otx.core.model.anomaly import OTXAnomaly
-from otx.core.model.base import OTXModel
-from otx.core.types.label import AnomalyLabelInfo
 from otx.core.types.task import OTXTaskType
 
 if TYPE_CHECKING:
     from lightning.pytorch.utilities.types import STEP_OUTPUT
     from torch.optim.optimizer import Optimizer
 
-    from otx.core.model.anomaly import AnomalyModelInputs
+    from otx.core.model.anomaly import AnomalyModelInputs, AnomalyModelOutputs
 
 
-class Padim(OTXAnomaly, OTXModel, AnomalibPadim):
+class Padim(OTXAnomaly, AnomalibPadim):
     """OTX Padim model.
 
     Args:
@@ -49,7 +49,6 @@ def __init__(
         ] = OTXTaskType.ANOMALY_CLASSIFICATION,
     ) -> None:
         OTXAnomaly.__init__(self)
-        OTXModel.__init__(self, label_info=AnomalyLabelInfo())
         AnomalibPadim.__init__(
             self,
             backbone=backbone,
@@ -132,3 +131,16 @@ def predict_step(
         if not isinstance(inputs, dict):
             inputs = self._customize_inputs(inputs)
         return AnomalibPadim.predict_step(self, inputs, batch_idx, **kwargs)  # type: ignore[misc]
+
+    def forward(
+        self,
+        inputs: AnomalyModelInputs,
+    ) -> AnomalyModelOutputs:
+        """Wrap forward method of the Anomalib model."""
+        outputs = self.validation_step(inputs)
+        # TODO(Ashwin): update forward implementation to comply with other OTX models
+        _PostProcessorCallback._post_process(outputs)  # noqa: SLF001
+        _PostProcessorCallback._compute_scores_and_labels(self, outputs)  # noqa: SLF001
+        _MinMaxNormalizationCallback._normalize_batch(outputs, self)  # noqa: SLF001
+
+        return self._customize_outputs(outputs=outputs, inputs=inputs)
diff --git a/src/otx/algo/anomaly/stfpm.py b/src/otx/algo/anomaly/stfpm.py
index 72dd30e8aa3..c9ddb4cd93c 100644
--- a/src/otx/algo/anomaly/stfpm.py
+++ b/src/otx/algo/anomaly/stfpm.py
@@ -9,21 +9,21 @@
 
 from typing import TYPE_CHECKING, Literal, Sequence
 
+from anomalib.callbacks.normalization.min_max_normalization import _MinMaxNormalizationCallback
+from anomalib.callbacks.post_processor import _PostProcessorCallback
 from anomalib.models.image.stfpm import Stfpm as AnomalibStfpm
 
 from otx.core.model.anomaly import OTXAnomaly
-from otx.core.model.base import OTXModel
-from otx.core.types.label import AnomalyLabelInfo
 from otx.core.types.task import OTXTaskType
 
 if TYPE_CHECKING:
     from lightning.pytorch.utilities.types import STEP_OUTPUT
     from torch.optim.optimizer import Optimizer
 
-    from otx.core.model.anomaly import AnomalyModelInputs
+    from otx.core.model.anomaly import AnomalyModelInputs, AnomalyModelOutputs
 
 
-class Stfpm(OTXAnomaly, OTXModel, AnomalibStfpm):
+class Stfpm(OTXAnomaly, AnomalibStfpm):
     """OTX STFPM model.
 
     Args:
@@ -46,7 +46,6 @@ def __init__(
         **kwargs,
     ) -> None:
         OTXAnomaly.__init__(self)
-        OTXModel.__init__(self, label_info=AnomalyLabelInfo())
         AnomalibStfpm.__init__(
             self,
             backbone=backbone,
@@ -124,3 +123,16 @@ def predict_step(
         if not isinstance(inputs, dict):
             inputs = self._customize_inputs(inputs)
         return AnomalibStfpm.predict_step(self, inputs, batch_idx, **kwargs)  # type: ignore[misc]
+
+    def forward(
+        self,
+        inputs: AnomalyModelInputs,
+    ) -> AnomalyModelOutputs:
+        """Wrap forward method of the Anomalib model."""
+        outputs = self.validation_step(inputs)
+        # TODO(Ashwin): update forward implementation to comply with other OTX models
+        _PostProcessorCallback._post_process(outputs)  # noqa: SLF001
+        _PostProcessorCallback._compute_scores_and_labels(self, outputs)  # noqa: SLF001
+        _MinMaxNormalizationCallback._normalize_batch(outputs, self)  # noqa: SLF001
+
+        return self._customize_outputs(outputs=outputs, inputs=inputs)
diff --git a/src/otx/algo/classification/backbones/efficientnet.py b/src/otx/algo/classification/backbones/efficientnet.py
index 55646d434bd..a7081728590 100644
--- a/src/otx/algo/classification/backbones/efficientnet.py
+++ b/src/otx/algo/classification/backbones/efficientnet.py
@@ -679,7 +679,7 @@ def init_weights(self, pretrained: bool | str | None = None) -> None:
             checkpoint = torch.load(pretrained, None)
             load_checkpoint_to_model(self, checkpoint)
             print(f"init weight - {pretrained}")
-        elif pretrained is not None:
+        elif pretrained:
             cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints"
             download_model(net=self, model_name=self.model_name, local_model_store_dir_path=str(cache_dir))
-            print(f"init weight - {pretrained_urls[self.model_name]}")
+            print(f"Download model weight in {cache_dir!s}")
diff --git a/src/otx/algo/classification/dino_v2.py b/src/otx/algo/classification/dino_v2.py
deleted file mode 100644
index fedbec588f5..00000000000
--- a/src/otx/algo/classification/dino_v2.py
+++ /dev/null
@@ -1,315 +0,0 @@
-# Copyright (C) 2023 Intel Corporation
-# SPDX-License-Identifier: Apache-2.0
-#
-"""DINO-V2 model for the OTX classification."""
-
-from __future__ import annotations
-
-import logging
-import os
-from pathlib import Path
-from typing import TYPE_CHECKING, Any, Literal
-
-import torch
-from torch import Tensor, nn
-
-from otx.algo.classification.classifier import SemiSLClassifier
-from otx.algo.classification.heads import OTXSemiSLLinearClsHead
-from otx.algo.classification.utils import get_classification_layers
-from otx.algo.utils.utils import torch_hub_load
-from otx.core.data.entity.base import OTXBatchLossEntity
-from otx.core.data.entity.classification import (
-    MulticlassClsBatchDataEntity,
-    MulticlassClsBatchPredEntity,
-)
-from otx.core.exporter.base import OTXModelExporter
-from otx.core.exporter.native import OTXNativeModelExporter
-from otx.core.metrics.accuracy import MultiClassClsMetricCallable
-from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
-from otx.core.model.classification import OTXMulticlassClsModel
-from otx.core.schedulers import LRSchedulerListCallable
-from otx.core.types.label import LabelInfoTypes
-from otx.utils.utils import get_class_initial_arguments
-
-if TYPE_CHECKING:
-    from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
-    from typing_extensions import Self
-
-    from otx.core.metrics import MetricCallable
-
-
-# TODO(harimkang): Add more types of DINOv2 models.
https://github.com/facebookresearch/dinov2/blob/main/MODEL_CARD.md -DINO_BACKBONE_TYPE = Literal["dinov2_vits14_reg"] - -logger = logging.getLogger() - - -class DINOv2(nn.Module): - """DINO-v2 Model.""" - - def __init__( - self, - backbone: DINO_BACKBONE_TYPE, - freeze_backbone: bool, - head_in_channels: int, - num_classes: int, - ): - super().__init__() - self._init_args = get_class_initial_arguments() - - ci_data_root = os.environ.get("CI_DATA_ROOT") - pretrained: bool = True - if ci_data_root is not None and Path(ci_data_root).exists(): - pretrained = False - - self.backbone = torch.hub.load( - repo_or_dir="facebookresearch/dinov2", - model=backbone, - pretrained=pretrained, - ) - - if ci_data_root is not None and Path(ci_data_root).exists(): - ckpt_filename = f"{backbone}4_pretrain.pth" - ckpt_path = Path(ci_data_root) / "torch" / "hub" / "checkpoints" / ckpt_filename - if not ckpt_path.exists(): - msg = ( - f"Internal cache was specified but cannot find weights file: {ckpt_filename}. load from torch hub." - ) - logger.warning(msg) - self.backbone = torch.hub.load(repo_or_dir="facebookresearch/dinov2", model=backbone, pretrained=True) - self.backbone.load_state_dict(torch.load(ckpt_path)) - - if freeze_backbone: - self._freeze_backbone(self.backbone) - - self.head = nn.Linear( - head_in_channels, - num_classes, - ) - - self.loss = nn.CrossEntropyLoss() - self.softmax = nn.Softmax() - - def _freeze_backbone(self, backbone: nn.Module) -> None: - """Freeze the backbone.""" - for _, v in backbone.named_parameters(): - v.requires_grad = False - - def forward(self, imgs: torch.Tensor, labels: torch.Tensor | None = None, **kwargs) -> torch.Tensor: - """Forward function.""" - feats = self.backbone(imgs) - logits = self.head(feats) - if self.training: - return self.loss(logits, labels) - return self.softmax(logits) - - def __reduce__(self): - return (DINOv2, self._init_args) - - -class DINOv2RegisterClassifier(OTXMulticlassClsModel): - """DINO-v2 Classification Model with register.""" - - def __init__( - self, - label_info: LabelInfoTypes, - backbone: DINO_BACKBONE_TYPE = "dinov2_vits14_reg", - optimizer: OptimizerCallable = DefaultOptimizerCallable, - scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, - metric: MetricCallable = MultiClassClsMetricCallable, - torch_compile: bool = False, - freeze_backbone: bool = False, - ) -> None: - self.backbone = backbone - self.freeze_backbone = freeze_backbone - - super().__init__( - label_info=label_info, - optimizer=optimizer, - scheduler=scheduler, - metric=metric, - torch_compile=torch_compile, - ) - - def _create_model(self) -> nn.Module: - # Get classification_layers for class-incr learning - sample_model_dict = self._build_model(num_classes=5).state_dict() - incremental_model_dict = self._build_model(num_classes=6).state_dict() - self.classification_layers = get_classification_layers( - sample_model_dict, - incremental_model_dict, - prefix="model.", - ) - - return self._build_model(num_classes=self.num_classes) - - def _build_model(self, num_classes: int) -> nn.Module: - """Create the model.""" - return DINOv2( - backbone=self.backbone, - freeze_backbone=self.freeze_backbone, - # TODO(harimkang): A feature should be added to allow in_channels to adjust based on the arch. 
- head_in_channels=384, - num_classes=num_classes, - ) - - def _customize_inputs(self, entity: MulticlassClsBatchDataEntity) -> dict[str, Any]: - """Customize the inputs for the model.""" - return { - "imgs": entity.stacked_images, - "labels": torch.cat(entity.labels), - "imgs_info": entity.imgs_info, - } - - def _customize_outputs( - self, - outputs: Any, # noqa: ANN401 - inputs: MulticlassClsBatchDataEntity, - ) -> MulticlassClsBatchPredEntity | OTXBatchLossEntity: - """Customize the outputs for the model.""" - if self.training: - if not isinstance(outputs, torch.Tensor): - raise TypeError(outputs) - - losses = OTXBatchLossEntity() - losses["loss"] = outputs - return losses - - max_pred_elements, max_pred_idxs = torch.max(outputs, dim=1) - pred_scores = max_pred_elements - pred_labels = max_pred_idxs - - scores = torch.unbind(pred_scores, dim=0) - labels = torch.unbind(pred_labels, dim=0) - - return MulticlassClsBatchPredEntity( - batch_size=pred_labels.shape[0], - images=inputs.images, - imgs_info=inputs.imgs_info, - scores=scores, - labels=labels, - ) - - @property - def _exporter(self) -> OTXModelExporter: - """Creates OTXModelExporter object that can export the model.""" - return OTXNativeModelExporter( - task_level_export_parameters=self._export_parameters, - input_size=(1, 3, 224, 224), - mean=(123.675, 116.28, 103.53), - std=(58.395, 57.12, 57.375), - resize_mode="standard", - pad_value=0, - swap_rgb=False, - via_onnx=False, - onnx_export_configuration=None, - output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, - ) - - @property - def _optimization_config(self) -> dict[str, Any]: - """PTQ config for DinoV2Cls.""" - return {"model_type": "transformer"} - - def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]: - """Model forward function used for the model tracing during model exportation.""" - return self.model(image) - - def to(self, *args, **kwargs) -> Self: - """Return a model with specified device.""" - ret = super().to(*args, **kwargs) - if self.device.type == "xpu": - msg = f"{type(self).__name__} doesn't support XPU." - raise RuntimeError(msg) - return ret - - -class DINOv2ForMulticlassClsSemiSL(DINOv2RegisterClassifier): - """DinoV2 model for multiclass classification with semi-supervised learning. - - This class extends the `DINOv2RegisterClassifier` class and adds support for semi-supervised learning. - It overrides the `_build_model` and `_customize_inputs` methods to incorporate the semi-supervised learning. - - Args: - DINOv2RegisterClassifier (class): The base class for DinoV2 multiclass classification. - """ - - def _build_model(self, num_classes: int) -> nn.Module: - backbone = torch_hub_load( - repo_or_dir="facebookresearch/dinov2", - model=self.backbone, - ) - if self.freeze_backbone: - for _, v in backbone.named_parameters(): - v.requires_grad = False - return SemiSLClassifier( - backbone=backbone, - neck=None, - head=OTXSemiSLLinearClsHead( - num_classes=num_classes, - in_channels=384, - loss=nn.CrossEntropyLoss(reduction="none"), - ), - ) - - def _customize_inputs(self, inputs: MulticlassClsBatchDataEntity) -> dict[str, Any]: - """Customizes the input data for the model based on the current mode. - - Args: - inputs (MulticlassClsBatchDataEntity): The input batch of data. - - Returns: - dict[str, Any]: The customized input data. 
- """ - if self.training: - mode = "loss" - elif self.explain_mode: - mode = "explain" - else: - mode = "predict" - - if isinstance(inputs, dict): - # When used with an unlabeled dataset, it comes in as a dict. - images = {key: inputs[key].images for key in inputs} - labels = {key: torch.cat(inputs[key].labels, dim=0) for key in inputs} - imgs_info = {key: inputs[key].imgs_info for key in inputs} - return { - "images": images, - "labels": labels, - "imgs_info": imgs_info, - "mode": mode, - } - return { - "images": inputs.stacked_images, - "labels": torch.cat(inputs.labels, dim=0), - "imgs_info": inputs.imgs_info, - "mode": mode, - } - - def training_step(self, batch: MulticlassClsBatchDataEntity, batch_idx: int) -> Tensor: - """Performs a single training step on a batch of data. - - Args: - batch (MulticlassClsBatchDataEntity): The input batch of data. - batch_idx (int): The index of the current batch. - - Returns: - Tensor: The computed loss for the training step. - """ - loss = super().training_step(batch, batch_idx) - # Collect metrics related to Semi-SL Training. - self.log( - "train/unlabeled_coef", - self.model.head.unlabeled_coef, - on_step=True, - on_epoch=False, - prog_bar=True, - ) - self.log( - "train/num_pseudo_label", - self.model.head.num_pseudo_label, - on_step=True, - on_epoch=False, - prog_bar=True, - ) - return loss diff --git a/src/otx/algo/classification/efficientnet.py b/src/otx/algo/classification/efficientnet.py index bba5dafd61b..ce036ae0aa4 100644 --- a/src/otx/algo/classification/efficientnet.py +++ b/src/otx/algo/classification/efficientnet.py @@ -7,9 +7,8 @@ from __future__ import annotations from copy import deepcopy -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Literal -import torch from torch import Tensor, nn from otx.algo.classification.backbones.efficientnet import EFFICIENTNET_VERSION, OTXEfficientNet @@ -25,7 +24,6 @@ from otx.algo.classification.necks.gap import GlobalAveragePooling from otx.algo.classification.utils import get_classification_layers from otx.algo.utils.support_otx_v1 import OTXv1Helper -from otx.core.data.entity.base import OTXBatchLossEntity from otx.core.data.entity.classification import ( HlabelClsBatchDataEntity, HlabelClsBatchPredEntity, @@ -34,14 +32,12 @@ MultilabelClsBatchDataEntity, MultilabelClsBatchPredEntity, ) -from otx.core.exporter.base import OTXModelExporter -from otx.core.exporter.native import OTXNativeModelExporter -from otx.core.metrics import MetricInput from otx.core.metrics.accuracy import HLabelClsMetricCallble, MultiClassClsMetricCallable, MultiLabelClsMetricCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable from otx.core.model.classification import OTXHlabelClsModel, OTXMulticlassClsModel, OTXMultilabelClsModel from otx.core.schedulers import LRSchedulerListCallable from otx.core.types.label import HLabelInfo, LabelInfoTypes +from otx.core.types.task import OTXTrainType if TYPE_CHECKING: from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable @@ -56,12 +52,15 @@ def __init__( self, label_info: LabelInfoTypes, version: EFFICIENTNET_VERSION = "b0", + pretrained: bool = True, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = MultiClassClsMetricCallable, torch_compile: bool = False, + train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED, ) -> None: self.version = 
version + self.pretrained = pretrained super().__init__( label_info=label_info, @@ -69,6 +68,7 @@ def __init__( scheduler=scheduler, metric=metric, torch_compile=torch_compile, + train_type=train_type, ) def _create_model(self) -> nn.Module: @@ -86,14 +86,28 @@ def _create_model(self) -> nn.Module: return model def _build_model(self, num_classes: int) -> nn.Module: + backbone = OTXEfficientNet(version=self.version, pretrained=self.pretrained) + neck = GlobalAveragePooling(dim=2) + loss = nn.CrossEntropyLoss(reduction="none") + if self.train_type == OTXTrainType.SEMI_SUPERVISED: + return SemiSLClassifier( + backbone=backbone, + neck=neck, + head=OTXSemiSLLinearClsHead( + num_classes=num_classes, + in_channels=backbone.num_features, + loss=loss, + ), + ) + return ImageClassifier( - backbone=OTXEfficientNet(version=self.version, pretrained=True), - neck=GlobalAveragePooling(dim=2), + backbone=backbone, + neck=neck, head=LinearClsHead( num_classes=num_classes, - in_channels=1280, + in_channels=backbone.num_features, topk=(1, 5) if num_classes >= 5 else (1,), - loss=nn.CrossEntropyLoss(reduction="none"), + loss=loss, ), ) @@ -101,61 +115,6 @@ def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> """Load the previous OTX ckpt according to OTX2.0.""" return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "multiclass", add_prefix) - def _reset_prediction_layer(self, num_classes: int) -> None: - return - - def _customize_inputs(self, inputs: MulticlassClsBatchDataEntity) -> dict[str, Any]: - if self.training: - mode = "loss" - elif self.explain_mode: - mode = "explain" - else: - mode = "predict" - - return { - "images": inputs.stacked_images, - "labels": torch.cat(inputs.labels, dim=0), - "imgs_info": inputs.imgs_info, - "mode": mode, - } - - def _customize_outputs( - self, - outputs: Any, # noqa: ANN401 - inputs: MulticlassClsBatchDataEntity, - ) -> MulticlassClsBatchPredEntity | OTXBatchLossEntity: - if self.training: - return OTXBatchLossEntity(loss=outputs) - - # To list, batch-wise - logits = outputs if isinstance(outputs, torch.Tensor) else outputs["logits"] - scores = torch.unbind(logits, 0) - preds = logits.argmax(-1, keepdim=True).unbind(0) - - return MulticlassClsBatchPredEntity( - batch_size=inputs.batch_size, - images=inputs.stacked_images, - imgs_info=inputs.imgs_info, - scores=scores, - labels=preds, - ) - - @property - def _exporter(self) -> OTXModelExporter: - """Creates OTXModelExporter object that can export the model.""" - return OTXNativeModelExporter( - task_level_export_parameters=self._export_parameters, - input_size=self.image_size, - mean=(123.675, 116.28, 103.53), - std=(58.395, 57.12, 57.375), - resize_mode="standard", - pad_value=0, - swap_rgb=False, - via_onnx=False, - onnx_export_configuration=None, - output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, - ) - def forward_explain(self, inputs: MulticlassClsBatchDataEntity) -> MulticlassClsBatchPredEntity: """Model forward explain function.""" outputs = self.model(images=inputs.stacked_images, mode="explain") @@ -178,93 +137,6 @@ def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]: return self.model(images=image, mode="tensor") -class EfficientNetForMulticlassClsSemiSL(EfficientNetForMulticlassCls): - """EfficientNet model for multiclass classification with semi-supervised learning. - - This class extends the `EfficientNetForMulticlassCls` class and adds support for semi-supervised learning. 
- It overrides the `_build_model` and `_customize_inputs` methods to incorporate the semi-supervised learning. - - Args: - EfficientNetForMulticlassCls (class): The base class for EfficientNet multiclass classification. - - Attributes: - version (str): The version of the EfficientNet model. - """ - - def _build_model(self, num_classes: int) -> nn.Module: - return SemiSLClassifier( - backbone=OTXEfficientNet(version=self.version, pretrained=True), - neck=GlobalAveragePooling(dim=2), - head=OTXSemiSLLinearClsHead( - num_classes=num_classes, - in_channels=1280, - loss=nn.CrossEntropyLoss(reduction="none"), - ), - ) - - def _customize_inputs(self, inputs: MulticlassClsBatchDataEntity) -> dict[str, Any]: - """Customizes the input data for the model based on the current mode. - - Args: - inputs (MulticlassClsBatchDataEntity): The input batch of data. - - Returns: - dict[str, Any]: The customized input data. - """ - if self.training: - mode = "loss" - elif self.explain_mode: - mode = "explain" - else: - mode = "predict" - - if isinstance(inputs, dict): - # When used with an unlabeled dataset, it comes in as a dict. - images = {key: inputs[key].images for key in inputs} - labels = {key: torch.cat(inputs[key].labels, dim=0) for key in inputs} - imgs_info = {key: inputs[key].imgs_info for key in inputs} - return { - "images": images, - "labels": labels, - "imgs_info": imgs_info, - "mode": mode, - } - return { - "images": inputs.stacked_images, - "labels": torch.cat(inputs.labels, dim=0), - "imgs_info": inputs.imgs_info, - "mode": mode, - } - - def training_step(self, batch: MulticlassClsBatchDataEntity, batch_idx: int) -> Tensor: - """Performs a single training step on a batch of data. - - Args: - batch (MulticlassClsBatchDataEntity): The input batch of data. - batch_idx (int): The index of the current batch. - - Returns: - Tensor: The computed loss for the training step. - """ - loss = super().training_step(batch, batch_idx) - # Collect metrics related to Semi-SL Training. 
- self.log( - "train/unlabeled_coef", - self.model.head.unlabeled_coef, - on_step=True, - on_epoch=False, - prog_bar=True, - ) - self.log( - "train/num_pseudo_label", - self.model.head.num_pseudo_label, - on_step=True, - on_epoch=False, - prog_bar=True, - ) - return loss - - class EfficientNetForMultilabelCls(OTXMultilabelClsModel): """EfficientNet Model for multi-label classification task.""" @@ -272,12 +144,14 @@ def __init__( self, label_info: LabelInfoTypes, version: EFFICIENTNET_VERSION = "b0", + pretrained: bool = True, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = MultiLabelClsMetricCallable, torch_compile: bool = False, ) -> None: self.version = version + self.pretrained = pretrained super().__init__( label_info=label_info, @@ -302,12 +176,13 @@ def _create_model(self) -> nn.Module: return model def _build_model(self, num_classes: int) -> nn.Module: + backbone = OTXEfficientNet(version=self.version, pretrained=self.pretrained) return ImageClassifier( - backbone=OTXEfficientNet(version=self.version, pretrained=True), + backbone=backbone, neck=GlobalAveragePooling(dim=2), head=MultiLabelLinearClsHead( num_classes=num_classes, - in_channels=1280, + in_channels=backbone.num_features, scale=7.0, normalized=True, loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), @@ -318,57 +193,6 @@ def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> """Load the previous OTX ckpt according to OTX2.0.""" return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "multilabel", add_prefix) - def _customize_inputs(self, inputs: MultilabelClsBatchDataEntity) -> dict[str, Any]: - if self.training: - mode = "loss" - elif self.explain_mode: - mode = "explain" - else: - mode = "predict" - - return { - "images": inputs.stacked_images, - "labels": torch.stack(inputs.labels), - "imgs_info": inputs.imgs_info, - "mode": mode, - } - - def _customize_outputs( - self, - outputs: Any, # noqa: ANN401 - inputs: MultilabelClsBatchDataEntity, - ) -> MultilabelClsBatchPredEntity | OTXBatchLossEntity: - if self.training: - return OTXBatchLossEntity(loss=outputs) - - # To list, batch-wise - logits = outputs if isinstance(outputs, torch.Tensor) else outputs["logits"] - scores = torch.unbind(logits, 0) - - return MultilabelClsBatchPredEntity( - batch_size=inputs.batch_size, - images=inputs.images, - imgs_info=inputs.imgs_info, - scores=scores, - labels=logits.argmax(-1, keepdim=True).unbind(0), - ) - - @property - def _exporter(self) -> OTXModelExporter: - """Creates OTXModelExporter object that can export the model.""" - return OTXNativeModelExporter( - task_level_export_parameters=self._export_parameters, - input_size=(1, 3, 224, 224), - mean=(123.675, 116.28, 103.53), - std=(58.395, 57.12, 57.375), - resize_mode="standard", - pad_value=0, - swap_rgb=False, - via_onnx=False, - onnx_export_configuration=None, - output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, - ) - def forward_explain(self, inputs: MultilabelClsBatchDataEntity) -> MultilabelClsBatchPredEntity: """Model forward explain function.""" outputs = self.model(images=inputs.stacked_images, mode="explain") @@ -400,12 +224,14 @@ def __init__( self, label_info: HLabelInfo, version: EFFICIENTNET_VERSION = "b0", + pretrained: bool = True, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = 
DefaultSchedulerCallable, metric: MetricCallable = HLabelClsMetricCallble, torch_compile: bool = False, ) -> None: self.version = version + self.pretrained = pretrained super().__init__( label_info=label_info, @@ -436,11 +262,12 @@ def _build_model(self, head_config: dict) -> nn.Module: if not isinstance(self.label_info, HLabelInfo): raise TypeError(self.label_info) + backbone = OTXEfficientNet(version=self.version, pretrained=self.pretrained) return ImageClassifier( - backbone=OTXEfficientNet(version=self.version, pretrained=True), + backbone=backbone, neck=GlobalAveragePooling(dim=2), head=HierarchicalLinearClsHead( - in_channels=1280, + in_channels=backbone.num_features, multiclass_loss=nn.CrossEntropyLoss(), multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), **head_config, @@ -451,81 +278,6 @@ def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> """Load the previous OTX ckpt according to OTX2.0.""" return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "hlabel", add_prefix) - def _customize_inputs(self, inputs: HlabelClsBatchDataEntity) -> dict[str, Any]: - if self.training: - mode = "loss" - elif self.explain_mode: - mode = "explain" - else: - mode = "predict" - - return { - "images": inputs.stacked_images, - "labels": torch.stack(inputs.labels), - "imgs_info": inputs.imgs_info, - "mode": mode, - } - - def _customize_outputs( - self, - outputs: Any, # noqa: ANN401 - inputs: HlabelClsBatchDataEntity, - ) -> HlabelClsBatchPredEntity | OTXBatchLossEntity: - if self.training: - return OTXBatchLossEntity(loss=outputs) - - # To list, batch-wise - if isinstance(outputs, dict): - scores = outputs["scores"] - labels = outputs["labels"] - else: - scores = outputs - labels = outputs.argmax(-1, keepdim=True) - - return HlabelClsBatchPredEntity( - batch_size=inputs.batch_size, - images=inputs.images, - imgs_info=inputs.imgs_info, - scores=scores, - labels=labels, - ) - - def _convert_pred_entity_to_compute_metric( - self, - preds: HlabelClsBatchPredEntity, - inputs: HlabelClsBatchDataEntity, - ) -> MetricInput: - hlabel_info: HLabelInfo = self.label_info # type: ignore[assignment] - - _labels = torch.stack(preds.labels) if isinstance(preds.labels, list) else preds.labels - _scores = torch.stack(preds.scores) if isinstance(preds.scores, list) else preds.scores - if hlabel_info.num_multilabel_classes > 0: - preds_multiclass = _labels[:, : hlabel_info.num_multiclass_heads] - preds_multilabel = _scores[:, hlabel_info.num_multiclass_heads :] - pred_result = torch.cat([preds_multiclass, preds_multilabel], dim=1) - else: - pred_result = _labels - return { - "preds": pred_result, - "target": torch.stack(inputs.labels), - } - - @property - def _exporter(self) -> OTXModelExporter: - """Creates OTXModelExporter object that can export the model.""" - return OTXNativeModelExporter( - task_level_export_parameters=self._export_parameters, - input_size=(1, 3, 224, 224), - mean=(123.675, 116.28, 103.53), - std=(58.395, 57.12, 57.375), - resize_mode="standard", - pad_value=0, - swap_rgb=False, - via_onnx=False, - onnx_export_configuration=None, - output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, - ) - def forward_explain(self, inputs: HlabelClsBatchDataEntity) -> HlabelClsBatchPredEntity: """Model forward explain function.""" outputs = self.model(images=inputs.stacked_images, mode="explain") diff --git a/src/otx/algo/classification/efficientnet_v2.py b/src/otx/algo/classification/efficientnet_v2.py 
deleted file mode 100644 index cc88cc3281b..00000000000 --- a/src/otx/algo/classification/efficientnet_v2.py +++ /dev/null @@ -1,531 +0,0 @@ -# Copyright (C) 2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# -"""EfficientNetV2 model implementation.""" -from __future__ import annotations - -from copy import deepcopy -from typing import TYPE_CHECKING, Any - -import torch -from torch import Tensor, nn - -from otx.algo.classification.backbones.timm import TimmBackbone -from otx.algo.classification.classifier import ImageClassifier, SemiSLClassifier -from otx.algo.classification.heads import ( - HierarchicalLinearClsHead, - LinearClsHead, - MultiLabelLinearClsHead, - OTXSemiSLLinearClsHead, -) -from otx.algo.classification.losses.asymmetric_angular_loss_with_ignore import AsymmetricAngularLossWithIgnore -from otx.algo.classification.necks.gap import GlobalAveragePooling -from otx.algo.classification.utils import get_classification_layers -from otx.algo.utils.support_otx_v1 import OTXv1Helper -from otx.core.data.entity.base import OTXBatchLossEntity -from otx.core.data.entity.classification import ( - HlabelClsBatchDataEntity, - HlabelClsBatchPredEntity, - MulticlassClsBatchDataEntity, - MulticlassClsBatchPredEntity, - MultilabelClsBatchDataEntity, - MultilabelClsBatchPredEntity, -) -from otx.core.exporter.base import OTXModelExporter -from otx.core.exporter.native import OTXNativeModelExporter -from otx.core.metrics import MetricInput -from otx.core.metrics.accuracy import HLabelClsMetricCallble, MultiClassClsMetricCallable, MultiLabelClsMetricCallable -from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable -from otx.core.model.classification import ( - OTXHlabelClsModel, - OTXMulticlassClsModel, - OTXMultilabelClsModel, -) -from otx.core.schedulers import LRSchedulerListCallable -from otx.core.types.label import HLabelInfo, LabelInfoTypes - -if TYPE_CHECKING: - from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable - - from otx.core.metrics import MetricCallable - - -class EfficientNetV2ForMulticlassCls(OTXMulticlassClsModel): - """EfficientNetV2 Model for multi-class classification task.""" - - def __init__( - self, - label_info: LabelInfoTypes, - optimizer: OptimizerCallable = DefaultOptimizerCallable, - scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, - metric: MetricCallable = MultiClassClsMetricCallable, - torch_compile: bool = False, - ) -> None: - super().__init__( - label_info=label_info, - optimizer=optimizer, - scheduler=scheduler, - metric=metric, - torch_compile=torch_compile, - ) - - def _create_model(self) -> nn.Module: - # Get classification_layers for class-incr learning - sample_model_dict = self._build_model(num_classes=5).state_dict() - incremental_model_dict = self._build_model(num_classes=6).state_dict() - self.classification_layers = get_classification_layers( - sample_model_dict, - incremental_model_dict, - prefix="model.", - ) - - model = self._build_model(num_classes=self.num_classes) - model.init_weights() - return model - - def _build_model(self, num_classes: int) -> nn.Module: - return ImageClassifier( - backbone=TimmBackbone(backbone="efficientnetv2_s_21k", pretrained=True), - neck=GlobalAveragePooling(dim=2), - head=LinearClsHead( - num_classes=num_classes, - in_channels=1280, - topk=(1, 5) if num_classes >= 5 else (1,), - loss=nn.CrossEntropyLoss(reduction="none"), - ), - ) - - def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: - 
"""Load the previous OTX ckpt according to OTX2.0.""" - return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "multiclass", add_prefix) - - def _customize_inputs(self, inputs: MulticlassClsBatchDataEntity) -> dict[str, Any]: - if self.training: - mode = "loss" - elif self.explain_mode: - mode = "explain" - else: - mode = "predict" - - return { - "images": inputs.stacked_images, - "labels": torch.cat(inputs.labels, dim=0), - "imgs_info": inputs.imgs_info, - "mode": mode, - } - - def _customize_outputs( - self, - outputs: Any, # noqa: ANN401 - inputs: MulticlassClsBatchDataEntity, - ) -> MulticlassClsBatchPredEntity | OTXBatchLossEntity: - if self.training: - return OTXBatchLossEntity(loss=outputs) - - # To list, batch-wise - logits = outputs if isinstance(outputs, torch.Tensor) else outputs["logits"] - scores = torch.unbind(logits, 0) - preds = logits.argmax(-1, keepdim=True).unbind(0) - - return MulticlassClsBatchPredEntity( - batch_size=inputs.batch_size, - images=inputs.images, - imgs_info=inputs.imgs_info, - scores=scores, - labels=preds, - ) - - @property - def _exporter(self) -> OTXModelExporter: - """Creates OTXModelExporter object that can export the model.""" - return OTXNativeModelExporter( - task_level_export_parameters=self._export_parameters, - input_size=self.image_size, - mean=(123.675, 116.28, 103.53), - std=(58.395, 57.12, 57.375), - resize_mode="standard", - pad_value=0, - swap_rgb=False, - via_onnx=False, - onnx_export_configuration=None, - output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, - ) - - def forward_explain(self, inputs: MulticlassClsBatchDataEntity) -> MulticlassClsBatchPredEntity: - """Model forward explain function.""" - outputs = self.model(images=inputs.stacked_images, mode="explain") - - return MulticlassClsBatchPredEntity( - batch_size=len(outputs["preds"]), - images=inputs.images, - imgs_info=inputs.imgs_info, - labels=outputs["preds"], - scores=outputs["scores"], - saliency_map=outputs["saliency_map"], - feature_vector=outputs["feature_vector"], - ) - - def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]: - """Model forward function used for the model tracing during model exportation.""" - if self.explain_mode: - return self.model(images=image, mode="explain") - - return self.model(images=image, mode="tensor") - - -class EfficientNetV2ForMulticlassClsSemiSL(EfficientNetV2ForMulticlassCls): - """EfficientNetV2 model for multiclass classification with semi-supervised learning. - - This class extends the `EfficientNetV2ForMulticlassCls` class and adds support for semi-supervised learning. - It overrides the `_build_model` and `_customize_inputs` methods to incorporate the semi-supervised learning. - - Args: - EfficientNetV2ForMulticlassCls (class): The base class for EfficientNetV2 multiclass classification. - """ - - def _build_model(self, num_classes: int) -> nn.Module: - return SemiSLClassifier( - backbone=TimmBackbone(backbone="efficientnetv2_s_21k", pretrained=True), - neck=GlobalAveragePooling(dim=2), - head=OTXSemiSLLinearClsHead( - num_classes=num_classes, - in_channels=1280, - loss=nn.CrossEntropyLoss(reduction="none"), - ), - ) - - def _customize_inputs(self, inputs: MulticlassClsBatchDataEntity) -> dict[str, Any]: - """Customizes the input data for the model based on the current mode. - - Args: - inputs (MulticlassClsBatchDataEntity): The input batch of data. - - Returns: - dict[str, Any]: The customized input data. 
- """ - if self.training: - mode = "loss" - elif self.explain_mode: - mode = "explain" - else: - mode = "predict" - - if isinstance(inputs, dict): - # When used with an unlabeled dataset, it comes in as a dict. - images = {key: inputs[key].images for key in inputs} - labels = {key: torch.cat(inputs[key].labels, dim=0) for key in inputs} - imgs_info = {key: inputs[key].imgs_info for key in inputs} - return { - "images": images, - "labels": labels, - "imgs_info": imgs_info, - "mode": mode, - } - return { - "images": inputs.stacked_images, - "labels": torch.cat(inputs.labels, dim=0), - "imgs_info": inputs.imgs_info, - "mode": mode, - } - - def training_step(self, batch: MulticlassClsBatchDataEntity, batch_idx: int) -> Tensor: - """Performs a single training step on a batch of data. - - Args: - batch (MulticlassClsBatchDataEntity): The input batch of data. - batch_idx (int): The index of the current batch. - - Returns: - Tensor: The computed loss for the training step. - """ - loss = super().training_step(batch, batch_idx) - # Collect metrics related to Semi-SL Training. - self.log( - "train/unlabeled_coef", - self.model.head.unlabeled_coef, - on_step=True, - on_epoch=False, - prog_bar=True, - ) - self.log( - "train/num_pseudo_label", - self.model.head.num_pseudo_label, - on_step=True, - on_epoch=False, - prog_bar=True, - ) - return loss - - -class EfficientNetV2ForMultilabelCls(OTXMultilabelClsModel): - """EfficientNetV2 Model for multi-label classification task.""" - - def __init__( - self, - label_info: LabelInfoTypes, - optimizer: OptimizerCallable = DefaultOptimizerCallable, - scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, - metric: MetricCallable = MultiLabelClsMetricCallable, - torch_compile: bool = False, - ) -> None: - super().__init__( - label_info=label_info, - optimizer=optimizer, - scheduler=scheduler, - metric=metric, - torch_compile=torch_compile, - ) - - def _create_model(self) -> nn.Module: - # Get classification_layers for class-incr learning - sample_model_dict = self._build_model(num_classes=5).state_dict() - incremental_model_dict = self._build_model(num_classes=6).state_dict() - self.classification_layers = get_classification_layers( - sample_model_dict, - incremental_model_dict, - prefix="model.", - ) - - model = self._build_model(num_classes=self.num_classes) - model.init_weights() - return model - - def _build_model(self, num_classes: int) -> nn.Module: - return ImageClassifier( - backbone=TimmBackbone(backbone="efficientnetv2_s_21k", pretrained=True), - neck=GlobalAveragePooling(dim=2), - head=MultiLabelLinearClsHead( - num_classes=num_classes, - in_channels=1280, - normalized=True, - scale=7.0, - loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), - ), - ) - - def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: - """Load the previous OTX ckpt according to OTX2.0.""" - return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "multilabel", add_prefix) - - def _customize_inputs(self, inputs: MultilabelClsBatchDataEntity) -> dict[str, Any]: - if self.training: - mode = "loss" - elif self.explain_mode: - mode = "explain" - else: - mode = "predict" - - return { - "images": inputs.stacked_images, - "labels": torch.stack(inputs.labels), - "imgs_info": inputs.imgs_info, - "mode": mode, - } - - def _customize_outputs( - self, - outputs: Any, # noqa: ANN401 - inputs: MultilabelClsBatchDataEntity, - ) -> MultilabelClsBatchPredEntity | OTXBatchLossEntity: - if self.training: - 
return OTXBatchLossEntity(loss=outputs) - - # To list, batch-wise - logits = outputs if isinstance(outputs, torch.Tensor) else outputs["logits"] - scores = torch.unbind(logits, 0) - - return MultilabelClsBatchPredEntity( - batch_size=inputs.batch_size, - images=inputs.images, - imgs_info=inputs.imgs_info, - scores=scores, - labels=logits.argmax(-1, keepdim=True).unbind(0), - ) - - @property - def _exporter(self) -> OTXModelExporter: - """Creates OTXModelExporter object that can export the model.""" - return OTXNativeModelExporter( - task_level_export_parameters=self._export_parameters, - input_size=(1, 3, 224, 224), - mean=(123.675, 116.28, 103.53), - std=(58.395, 57.12, 57.375), - resize_mode="standard", - pad_value=0, - swap_rgb=False, - via_onnx=False, - onnx_export_configuration=None, - output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, - ) - - def forward_explain(self, inputs: MultilabelClsBatchDataEntity) -> MultilabelClsBatchPredEntity: - """Model forward explain function.""" - outputs = self.model(images=inputs.stacked_images, mode="explain") - - return MultilabelClsBatchPredEntity( - batch_size=len(outputs["preds"]), - images=inputs.images, - imgs_info=inputs.imgs_info, - labels=outputs["preds"], - scores=outputs["scores"], - saliency_map=outputs["saliency_map"], - feature_vector=outputs["feature_vector"], - ) - - def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]: - """Model forward function used for the model tracing during model exportation.""" - if self.explain_mode: - return self.model(images=image, mode="explain") - - return self.model(images=image, mode="tensor") - - -class EfficientNetV2ForHLabelCls(OTXHlabelClsModel): - """EfficientNetV2 Model for hierarchical label classification task.""" - - label_info: HLabelInfo - - def __init__( - self, - label_info: HLabelInfo, - optimizer: OptimizerCallable = DefaultOptimizerCallable, - scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, - metric: MetricCallable = HLabelClsMetricCallble, - torch_compile: bool = False, - ) -> None: - super().__init__( - label_info=label_info, - optimizer=optimizer, - scheduler=scheduler, - metric=metric, - torch_compile=torch_compile, - ) - - def _create_model(self) -> nn.Module: - # Get classification_layers for class-incr learning - sample_config = deepcopy(self.label_info.as_head_config_dict()) - sample_config["num_classes"] = 5 - sample_model_dict = self._build_model(head_config=sample_config).state_dict() - sample_config["num_classes"] = 6 - incremental_model_dict = self._build_model(head_config=sample_config).state_dict() - self.classification_layers = get_classification_layers( - sample_model_dict, - incremental_model_dict, - prefix="model.", - ) - - model = self._build_model(head_config=self.label_info.as_head_config_dict()) - model.init_weights() - return model - - def _build_model(self, head_config: dict) -> nn.Module: - return ImageClassifier( - backbone=TimmBackbone(backbone="efficientnetv2_s_21k", pretrained=True), - neck=GlobalAveragePooling(dim=2), - head=HierarchicalLinearClsHead( - in_channels=1280, - multiclass_loss=nn.CrossEntropyLoss(), - multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), - **head_config, - ), - ) - - def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: - """Load the previous OTX ckpt according to OTX2.0.""" - return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "hlabel", 
add_prefix) - - def _customize_inputs(self, inputs: HlabelClsBatchDataEntity) -> dict[str, Any]: - if self.training: - mode = "loss" - elif self.explain_mode: - mode = "explain" - else: - mode = "predict" - - return { - "images": inputs.stacked_images, - "labels": torch.stack(inputs.labels), - "imgs_info": inputs.imgs_info, - "mode": mode, - } - - def _customize_outputs( - self, - outputs: Any, # noqa: ANN401 - inputs: HlabelClsBatchDataEntity, - ) -> HlabelClsBatchPredEntity | OTXBatchLossEntity: - if self.training: - return OTXBatchLossEntity(loss=outputs) - - # To list, batch-wise - if isinstance(outputs, dict): - scores = outputs["scores"] - labels = outputs["labels"] - else: - scores = outputs - labels = outputs.argmax(-1, keepdim=True) - - return HlabelClsBatchPredEntity( - batch_size=inputs.batch_size, - images=inputs.images, - imgs_info=inputs.imgs_info, - scores=scores, - labels=labels, - ) - - def _convert_pred_entity_to_compute_metric( - self, - preds: HlabelClsBatchPredEntity, - inputs: HlabelClsBatchDataEntity, - ) -> MetricInput: - hlabel_info: HLabelInfo = self.label_info # type: ignore[assignment] - - _labels = torch.stack(preds.labels) if isinstance(preds.labels, list) else preds.labels - _scores = torch.stack(preds.scores) if isinstance(preds.scores, list) else preds.scores - if hlabel_info.num_multilabel_classes > 0: - preds_multiclass = _labels[:, : hlabel_info.num_multiclass_heads] - preds_multilabel = _scores[:, hlabel_info.num_multiclass_heads :] - pred_result = torch.cat([preds_multiclass, preds_multilabel], dim=1) - else: - pred_result = _labels - return { - "preds": pred_result, - "target": torch.stack(inputs.labels), - } - - @property - def _exporter(self) -> OTXModelExporter: - """Creates OTXModelExporter object that can export the model.""" - return OTXNativeModelExporter( - task_level_export_parameters=self._export_parameters, - input_size=(1, 3, 224, 224), - mean=(123.675, 116.28, 103.53), - std=(58.395, 57.12, 57.375), - resize_mode="standard", - pad_value=0, - swap_rgb=False, - via_onnx=False, - onnx_export_configuration=None, - output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, - ) - - def forward_explain(self, inputs: HlabelClsBatchDataEntity) -> HlabelClsBatchPredEntity: - """Model forward explain function.""" - outputs = self.model(images=inputs.stacked_images, mode="explain") - - return HlabelClsBatchPredEntity( - batch_size=len(outputs["preds"]), - images=inputs.images, - imgs_info=inputs.imgs_info, - labels=outputs["preds"], - scores=outputs["scores"], - saliency_map=outputs["saliency_map"], - feature_vector=outputs["feature_vector"], - ) - - def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]: - """Model forward function used for the model tracing during model exportation.""" - if self.explain_mode: - return self.model(images=image, mode="explain") - - return self.model(images=image, mode="tensor") diff --git a/src/otx/algo/classification/huggingface_model.py b/src/otx/algo/classification/huggingface_model.py index 906b65b4123..f5945bc12e1 100644 --- a/src/otx/algo/classification/huggingface_model.py +++ b/src/otx/algo/classification/huggingface_model.py @@ -16,8 +16,6 @@ MulticlassClsBatchDataEntity, MulticlassClsBatchPredEntity, ) -from otx.core.exporter.base import OTXModelExporter -from otx.core.exporter.native import OTXNativeModelExporter from otx.core.metrics.accuracy import MultiClassClsMetricCallable from otx.core.model.base import DefaultOptimizerCallable, 
DefaultSchedulerCallable from otx.core.model.classification import OTXMulticlassClsModel @@ -105,22 +103,6 @@ def _customize_outputs( labels=preds, ) - @property - def _exporter(self) -> OTXModelExporter: - """Creates OTXModelExporter object that can export the model.""" - return OTXNativeModelExporter( - task_level_export_parameters=self._export_parameters, - input_size=self.image_size, - mean=(123.675, 116.28, 103.53), - std=(58.395, 57.12, 57.375), - resize_mode="standard", - pad_value=0, - swap_rgb=False, - via_onnx=False, - onnx_export_configuration=None, - output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, - ) - def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]: """Model forward function used for the model tracing during model exportation.""" if self.explain_mode: diff --git a/src/otx/algo/classification/mobilenet_v3.py b/src/otx/algo/classification/mobilenet_v3.py index 70eeebff731..756b52721ae 100644 --- a/src/otx/algo/classification/mobilenet_v3.py +++ b/src/otx/algo/classification/mobilenet_v3.py @@ -41,6 +41,7 @@ from otx.core.model.classification import OTXHlabelClsModel, OTXMulticlassClsModel, OTXMultilabelClsModel from otx.core.schedulers import LRSchedulerListCallable from otx.core.types.label import HLabelInfo, LabelInfoTypes +from otx.core.types.task import OTXTrainType if TYPE_CHECKING: from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable @@ -71,14 +72,17 @@ def __init__( scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = MultiClassClsMetricCallable, torch_compile: bool = False, + train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED, ) -> None: self.mode = mode + super().__init__( label_info=label_info, optimizer=optimizer, scheduler=scheduler, metric=metric, torch_compile=torch_compile, + train_type=train_type, ) def _create_model(self) -> nn.Module: @@ -96,14 +100,29 @@ def _create_model(self) -> nn.Module: return model def _build_model(self, num_classes: int) -> nn.Module: + backbone = OTXMobileNetV3(mode=self.mode) + neck = GlobalAveragePooling(dim=2) + loss = nn.CrossEntropyLoss(reduction="none") + in_channels = 960 if self.mode == "large" else 576 + if self.train_type == OTXTrainType.SEMI_SUPERVISED: + return SemiSLClassifier( + backbone=backbone, + neck=neck, + head=OTXSemiSLLinearClsHead( + num_classes=num_classes, + in_channels=in_channels, + loss=loss, + ), + ) + return ImageClassifier( - backbone=OTXMobileNetV3(mode=self.mode), - neck=GlobalAveragePooling(dim=2), + backbone=backbone, + neck=neck, head=LinearClsHead( num_classes=num_classes, - in_channels=960, + in_channels=in_channels, topk=(1, 5) if num_classes >= 5 else (1,), - loss=nn.CrossEntropyLoss(reduction="none"), + loss=loss, ), ) @@ -111,58 +130,6 @@ def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> """Load the previous OTX ckpt according to OTX2.0.""" return OTXv1Helper.load_cls_mobilenet_v3_ckpt(state_dict, "multiclass", add_prefix) - def _customize_inputs(self, inputs: MulticlassClsBatchDataEntity) -> dict[str, Any]: - if self.training: - mode = "loss" - elif self.explain_mode: - mode = "explain" - else: - mode = "predict" - - return { - "images": inputs.stacked_images, - "labels": torch.cat(inputs.labels, dim=0), - "imgs_info": inputs.imgs_info, - "mode": mode, - } - - def _customize_outputs( - self, - outputs: Any, # noqa: ANN401 - inputs: MulticlassClsBatchDataEntity, - ) -> 
MulticlassClsBatchPredEntity | OTXBatchLossEntity: - if self.training: - return OTXBatchLossEntity(loss=outputs) - - # To list, batch-wise - logits = outputs if isinstance(outputs, torch.Tensor) else outputs["logits"] - scores = torch.unbind(logits, 0) - preds = logits.argmax(-1, keepdim=True).unbind(0) - - return MulticlassClsBatchPredEntity( - batch_size=inputs.batch_size, - images=inputs.images, - imgs_info=inputs.imgs_info, - scores=scores, - labels=preds, - ) - - @property - def _exporter(self) -> OTXModelExporter: - """Creates OTXModelExporter object that can export the model.""" - return OTXNativeModelExporter( - task_level_export_parameters=self._export_parameters, - input_size=self.image_size, - mean=(123.675, 116.28, 103.53), - std=(58.395, 57.12, 57.375), - resize_mode="standard", - pad_value=0, - swap_rgb=False, - via_onnx=False, - onnx_export_configuration=None, - output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, - ) - def forward_explain(self, inputs: MulticlassClsBatchDataEntity) -> MulticlassClsBatchPredEntity: """Model forward explain function.""" outputs = self.model(images=inputs.stacked_images, mode="explain") @@ -185,93 +152,6 @@ def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]: return self.model(images=image, mode="tensor") -class MobileNetV3ForMulticlassClsSemiSL(MobileNetV3ForMulticlassCls): - """MobileNetV3 model for multiclass classification with semi-supervised learning. - - This class extends the `MobileNetV3ForMulticlassCls` class and adds support for semi-supervised learning. - It overrides the `_build_model` and `_customize_inputs` methods to incorporate the semi-supervised learning. - - Args: - MobileNetV3ForMulticlassCls (class): The base class for MobileNetV3 multiclass classification. - - Attributes: - mode (str): The mode of the OTXMobileNetV3 model. - """ - - def _build_model(self, num_classes: int) -> nn.Module: - return SemiSLClassifier( - backbone=OTXMobileNetV3(mode=self.mode), - neck=GlobalAveragePooling(dim=2), - head=OTXSemiSLLinearClsHead( - num_classes=num_classes, - in_channels=960, - loss=nn.CrossEntropyLoss(reduction="none"), - ), - ) - - def _customize_inputs(self, inputs: MulticlassClsBatchDataEntity) -> dict[str, Any]: - """Customizes the input data for the model based on the current mode. - - Args: - inputs (MulticlassClsBatchDataEntity): The input batch of data. - - Returns: - dict[str, Any]: The customized input data. - """ - if self.training: - mode = "loss" - elif self.explain_mode: - mode = "explain" - else: - mode = "predict" - - if isinstance(inputs, dict): - # When used with an unlabeled dataset, it comes in as a dict. - images = {key: inputs[key].images for key in inputs} - labels = {key: torch.cat(inputs[key].labels, dim=0) for key in inputs} - imgs_info = {key: inputs[key].imgs_info for key in inputs} - return { - "images": images, - "labels": labels, - "imgs_info": imgs_info, - "mode": mode, - } - return { - "images": inputs.stacked_images, - "labels": torch.cat(inputs.labels, dim=0), - "imgs_info": inputs.imgs_info, - "mode": mode, - } - - def training_step(self, batch: MulticlassClsBatchDataEntity, batch_idx: int) -> Tensor: - """Performs a single training step on a batch of data. - - Args: - batch (MulticlassClsBatchDataEntity): The input batch of data. - batch_idx (int): The index of the current batch. - - Returns: - Tensor: The computed loss for the training step. 
-        """
-        loss = super().training_step(batch, batch_idx)
-        # Collect metrics related to Semi-SL Training.
-        self.log(
-            "train/unlabeled_coef",
-            self.model.head.unlabeled_coef,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-        )
-        self.log(
-            "train/num_pseudo_label",
-            self.model.head.num_pseudo_label,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-        )
-        return loss
-
-
 class MobileNetV3ForMultilabelCls(OTXMultilabelClsModel):
     """MobileNetV3 Model for multi-class classification task."""
diff --git a/src/otx/algo/classification/timm_model.py b/src/otx/algo/classification/timm_model.py
new file mode 100644
index 00000000000..04eaf5ff396
--- /dev/null
+++ b/src/otx/algo/classification/timm_model.py
@@ -0,0 +1,299 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+"""EfficientNetV2 model implementation."""
+from __future__ import annotations
+
+from copy import deepcopy
+from typing import TYPE_CHECKING, Literal
+
+import torch
+from torch import nn
+
+from otx.algo.classification.backbones.timm import TimmBackbone, TimmModelType
+from otx.algo.classification.classifier import ImageClassifier, SemiSLClassifier
+from otx.algo.classification.heads import (
+    HierarchicalLinearClsHead,
+    LinearClsHead,
+    MultiLabelLinearClsHead,
+    OTXSemiSLLinearClsHead,
+)
+from otx.algo.classification.losses.asymmetric_angular_loss_with_ignore import AsymmetricAngularLossWithIgnore
+from otx.algo.classification.necks.gap import GlobalAveragePooling
+from otx.algo.classification.utils import get_classification_layers
+from otx.algo.utils.support_otx_v1 import OTXv1Helper
+from otx.core.data.entity.classification import (
+    HlabelClsBatchDataEntity,
+    HlabelClsBatchPredEntity,
+    MulticlassClsBatchDataEntity,
+    MulticlassClsBatchPredEntity,
+    MultilabelClsBatchDataEntity,
+    MultilabelClsBatchPredEntity,
+)
+from otx.core.metrics.accuracy import HLabelClsMetricCallble, MultiClassClsMetricCallable, MultiLabelClsMetricCallable
+from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
+from otx.core.model.classification import (
+    OTXHlabelClsModel,
+    OTXMulticlassClsModel,
+    OTXMultilabelClsModel,
+)
+from otx.core.schedulers import LRSchedulerListCallable
+from otx.core.types.label import HLabelInfo, LabelInfoTypes
+from otx.core.types.task import OTXTrainType
+
+if TYPE_CHECKING:
+    from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
+
+    from otx.core.metrics import MetricCallable
+
+
+class TimmModelForMulticlassCls(OTXMulticlassClsModel):
+    """TimmModel for multi-class classification task."""
+
+    def __init__(
+        self,
+        label_info: LabelInfoTypes,
+        backbone: TimmModelType,
+        pretrained: bool = True,
+        optimizer: OptimizerCallable = DefaultOptimizerCallable,
+        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
+        metric: MetricCallable = MultiClassClsMetricCallable,
+        torch_compile: bool = False,
+        train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED,
+    ) -> None:
+        self.backbone = backbone
+        self.pretrained = pretrained
+
+        super().__init__(
+            label_info=label_info,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            metric=metric,
+            torch_compile=torch_compile,
+            train_type=train_type,
+        )
+
+    def _create_model(self) -> nn.Module:
+        # Get classification_layers for class-incr learning
+        sample_model_dict = self._build_model(num_classes=5).state_dict()
+        incremental_model_dict = self._build_model(num_classes=6).state_dict()
+        self.classification_layers = get_classification_layers(
+            sample_model_dict,
+            incremental_model_dict,
+            prefix="model.",
+        )
+
+        model = self._build_model(num_classes=self.num_classes)
+        model.init_weights()
+        return model
+
+    def _build_model(self, num_classes: int) -> nn.Module:
+        backbone = TimmBackbone(backbone=self.backbone, pretrained=self.pretrained)
+        neck = GlobalAveragePooling(dim=2)
+        loss = nn.CrossEntropyLoss(reduction="none")
+        if self.train_type == OTXTrainType.SEMI_SUPERVISED:
+            return SemiSLClassifier(
+                backbone=backbone,
+                neck=neck,
+                head=OTXSemiSLLinearClsHead(
+                    num_classes=num_classes,
+                    in_channels=backbone.num_features,
+                    loss=loss,
+                ),
+            )
+
+        return ImageClassifier(
+            backbone=backbone,
+            neck=neck,
+            head=LinearClsHead(
+                num_classes=num_classes,
+                in_channels=backbone.num_features,
+                topk=(1, 5) if num_classes >= 5 else (1,),
+                loss=loss,
+            ),
+        )
+
+    def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
+        """Load the previous OTX ckpt according to OTX2.0."""
+        return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "multiclass", add_prefix)
+
+    def forward_explain(self, inputs: MulticlassClsBatchDataEntity) -> MulticlassClsBatchPredEntity:
+        """Model forward explain function."""
+        outputs = self.model(images=inputs.stacked_images, mode="explain")
+
+        return MulticlassClsBatchPredEntity(
+            batch_size=len(outputs["preds"]),
+            images=inputs.images,
+            imgs_info=inputs.imgs_info,
+            labels=outputs["preds"],
+            scores=outputs["scores"],
+            saliency_map=outputs["saliency_map"],
+            feature_vector=outputs["feature_vector"],
+        )
+
+    def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]:
+        """Model forward function used for the model tracing during model exportation."""
+        if self.explain_mode:
+            return self.model(images=image, mode="explain")
+
+        return self.model(images=image, mode="tensor")
+
+
+class TimmModelForMultilabelCls(OTXMultilabelClsModel):
+    """TimmModel for multi-label classification task."""
+
+    def __init__(
+        self,
+        label_info: LabelInfoTypes,
+        backbone: TimmModelType,
+        pretrained: bool = True,
+        optimizer: OptimizerCallable = DefaultOptimizerCallable,
+        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
+        metric: MetricCallable = MultiLabelClsMetricCallable,
+        torch_compile: bool = False,
+    ) -> None:
+        self.backbone = backbone
+        self.pretrained = pretrained
+
+        super().__init__(
+            label_info=label_info,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            metric=metric,
+            torch_compile=torch_compile,
+        )
+
+    def _create_model(self) -> nn.Module:
+        # Get classification_layers for class-incr learning
+        sample_model_dict = self._build_model(num_classes=5).state_dict()
+        incremental_model_dict = self._build_model(num_classes=6).state_dict()
+        self.classification_layers = get_classification_layers(
+            sample_model_dict,
+            incremental_model_dict,
+            prefix="model.",
+        )
+
+        model = self._build_model(num_classes=self.num_classes)
+        model.init_weights()
+        return model
+
+    def _build_model(self, num_classes: int) -> nn.Module:
+        backbone = TimmBackbone(backbone=self.backbone, pretrained=self.pretrained)
+        return ImageClassifier(
+            backbone=backbone,
+            neck=GlobalAveragePooling(dim=2),
+            head=MultiLabelLinearClsHead(
+                num_classes=num_classes,
+                in_channels=backbone.num_features,
+                normalized=True,
+                scale=7.0,
+                loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
+            ),
+        )
+
+    def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
+        """Load the previous OTX ckpt according to OTX2.0."""
+        return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "multilabel", add_prefix)
+
+    def forward_explain(self, inputs: MultilabelClsBatchDataEntity) -> MultilabelClsBatchPredEntity:
+        """Model forward explain function."""
+        outputs = self.model(images=inputs.stacked_images, mode="explain")
+
+        return MultilabelClsBatchPredEntity(
+            batch_size=len(outputs["preds"]),
+            images=inputs.images,
+            imgs_info=inputs.imgs_info,
+            labels=outputs["preds"],
+            scores=outputs["scores"],
+            saliency_map=outputs["saliency_map"],
+            feature_vector=outputs["feature_vector"],
+        )
+
+    def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]:
+        """Model forward function used for the model tracing during model exportation."""
+        if self.explain_mode:
+            return self.model(images=image, mode="explain")
+
+        return self.model(images=image, mode="tensor")
+
+
+class TimmModelForHLabelCls(OTXHlabelClsModel):
+    """EfficientNetV2 Model for hierarchical label classification task."""
+
+    label_info: HLabelInfo
+
+    def __init__(
+        self,
+        label_info: HLabelInfo,
+        backbone: TimmModelType,
+        pretrained: bool = True,
+        optimizer: OptimizerCallable = DefaultOptimizerCallable,
+        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
+        metric: MetricCallable = HLabelClsMetricCallble,
+        torch_compile: bool = False,
+    ) -> None:
+        self.backbone = backbone
+        self.pretrained = pretrained
+
+        super().__init__(
+            label_info=label_info,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            metric=metric,
+            torch_compile=torch_compile,
+        )
+
+    def _create_model(self) -> nn.Module:
+        # Get classification_layers for class-incr learning
+        sample_config = deepcopy(self.label_info.as_head_config_dict())
+        sample_config["num_classes"] = 5
+        sample_model_dict = self._build_model(head_config=sample_config).state_dict()
+        sample_config["num_classes"] = 6
+        incremental_model_dict = self._build_model(head_config=sample_config).state_dict()
+        self.classification_layers = get_classification_layers(
+            sample_model_dict,
+            incremental_model_dict,
+            prefix="model.",
+        )
+
+        model = self._build_model(head_config=self.label_info.as_head_config_dict())
+        model.init_weights()
+        return model
+
+    def _build_model(self, head_config: dict) -> nn.Module:
+        backbone = TimmBackbone(backbone=self.backbone, pretrained=self.pretrained)
+        return ImageClassifier(
+            backbone=backbone,
+            neck=GlobalAveragePooling(dim=2),
+            head=HierarchicalLinearClsHead(
+                in_channels=backbone.num_features,
+                multiclass_loss=nn.CrossEntropyLoss(),
+                multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
+                **head_config,
+            ),
+        )
+
+    def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
+        """Load the previous OTX ckpt according to OTX2.0."""
+        return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "hlabel", add_prefix)
+
+    def forward_explain(self, inputs: HlabelClsBatchDataEntity) -> HlabelClsBatchPredEntity:
+        """Model forward explain function."""
+        outputs = self.model(images=inputs.stacked_images, mode="explain")
+
+        return HlabelClsBatchPredEntity(
+            batch_size=len(outputs["preds"]),
+            images=inputs.images,
+            imgs_info=inputs.imgs_info,
+            labels=outputs["preds"],
+            scores=outputs["scores"],
+            saliency_map=outputs["saliency_map"],
+            feature_vector=outputs["feature_vector"],
+        )
+
+    def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]:
+        """Model forward
function used for the model tracing during model exportation.""" + if self.explain_mode: + return self.model(images=image, mode="explain") + + return self.model(images=image, mode="tensor") diff --git a/src/otx/algo/classification/vit.py b/src/otx/algo/classification/vit.py index a6d8a1a34c5..f2ccd09b8d9 100644 --- a/src/otx/algo/classification/vit.py +++ b/src/otx/algo/classification/vit.py @@ -7,12 +7,12 @@ import types from copy import deepcopy from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Generic +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal from urllib.parse import urlparse import numpy as np import torch -from torch import Tensor, nn +from torch import nn from torch.hub import download_url_to_file from otx.algo.classification.backbones.vision_transformer import VIT_ARCH_TYPE, VisionTransformer @@ -27,18 +27,7 @@ from otx.algo.classification.utils import get_classification_layers from otx.algo.explain.explain_algo import ViTReciproCAM, feature_vector_fn from otx.algo.utils.support_otx_v1 import OTXv1Helper -from otx.core.data.entity.base import OTXBatchLossEntity, T_OTXBatchDataEntity, T_OTXBatchPredEntity -from otx.core.data.entity.classification import ( - HlabelClsBatchDataEntity, - HlabelClsBatchPredEntity, - MulticlassClsBatchDataEntity, - MulticlassClsBatchPredEntity, - MultilabelClsBatchDataEntity, - MultilabelClsBatchPredEntity, -) -from otx.core.exporter.base import OTXModelExporter -from otx.core.exporter.native import OTXNativeModelExporter -from otx.core.metrics import MetricInput +from otx.core.data.entity.base import T_OTXBatchDataEntity, T_OTXBatchPredEntity from otx.core.metrics.accuracy import HLabelClsMetricCallble, MultiClassClsMetricCallable, MultiLabelClsMetricCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable from otx.core.model.classification import ( @@ -48,6 +37,7 @@ ) from otx.core.schedulers import LRSchedulerListCallable from otx.core.types.label import HLabelInfo, LabelInfoTypes +from otx.core.types.task import OTXTrainType if TYPE_CHECKING: from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable @@ -229,16 +219,19 @@ def __init__( scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = MultiClassClsMetricCallable, torch_compile: bool = False, + train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED, ) -> None: self.arch = arch self.lora = lora self.pretrained = pretrained + super().__init__( label_info=label_info, optimizer=optimizer, scheduler=scheduler, metric=metric, torch_compile=torch_compile, + train_type=train_type, ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: @@ -285,6 +278,18 @@ def _build_model(self, num_classes: int) -> nn.Module: {"bias": 0.0, "val": 1.0, "layer": "LayerNorm", "type": "Constant"}, ] vit_backbone = VisionTransformer(arch=self.arch, img_size=224, lora=self.lora) + if self.train_type == OTXTrainType.SEMI_SUPERVISED: + return SemiSLClassifier( + backbone=vit_backbone, + neck=None, + head=OTXSemiSLVisionTransformerClsHead( + num_classes=num_classes, + in_channels=vit_backbone.embed_dim, + loss=nn.CrossEntropyLoss(reduction="none"), + ), + init_cfg=init_cfg, + ) + return ImageClassifier( backbone=vit_backbone, neck=None, @@ -297,69 +302,6 @@ def _build_model(self, num_classes: int) -> nn.Module: init_cfg=init_cfg, ) - def _customize_inputs(self, inputs: MulticlassClsBatchDataEntity) -> 
dict[str, Any]: - if self.training: - mode = "loss" - elif self.explain_mode: - mode = "explain" - else: - mode = "predict" - - return { - "images": inputs.stacked_images, - "labels": torch.cat(inputs.labels, dim=0), - "imgs_info": inputs.imgs_info, - "mode": mode, - } - - def _customize_outputs( - self, - outputs: Any, # noqa: ANN401 - inputs: MulticlassClsBatchDataEntity, - ) -> MulticlassClsBatchPredEntity | OTXBatchLossEntity: - if self.training: - return OTXBatchLossEntity(loss=outputs) - - if self.explain_mode: - return MulticlassClsBatchPredEntity( - batch_size=inputs.batch_size, - images=inputs.images, - imgs_info=inputs.imgs_info, - scores=outputs["scores"], - labels=outputs["labels"], - saliency_map=outputs["saliency_map"], - feature_vector=outputs["feature_vector"], - ) - - # To list, batch-wise - logits = outputs if isinstance(outputs, torch.Tensor) else outputs["logits"] - scores = torch.unbind(logits, 0) - preds = logits.argmax(-1, keepdim=True).unbind(0) - - return MulticlassClsBatchPredEntity( - batch_size=inputs.batch_size, - images=inputs.images, - imgs_info=inputs.imgs_info, - scores=scores, - labels=preds, - ) - - @property - def _exporter(self) -> OTXModelExporter: - """Creates OTXModelExporter object that can export the model.""" - return OTXNativeModelExporter( - task_level_export_parameters=self._export_parameters, - input_size=self.image_size, - mean=(123.675, 116.28, 103.53), - std=(58.395, 57.12, 57.375), - resize_mode="standard", - pad_value=0, - swap_rgb=False, - via_onnx=False, - onnx_export_configuration=None, - output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, - ) - class VisionTransformerForMulticlassClsSemiSL(VisionTransformerForMulticlassCls): """VisionTransformer model for multiclass classification with semi-supervised learning. @@ -388,68 +330,6 @@ def _build_model(self, num_classes: int) -> nn.Module: init_cfg=init_cfg, ) - def _customize_inputs(self, inputs: MulticlassClsBatchDataEntity) -> dict[str, Any]: - """Customizes the input data for the model based on the current mode. - - Args: - inputs (MulticlassClsBatchDataEntity): The input batch of data. - - Returns: - dict[str, Any]: The customized input data. - """ - if self.training: - mode = "loss" - elif self.explain_mode: - mode = "explain" - else: - mode = "predict" - - if isinstance(inputs, dict): - # When used with an unlabeled dataset, it comes in as a dict. - images = {key: inputs[key].images for key in inputs} - labels = {key: torch.cat(inputs[key].labels, dim=0) for key in inputs} - imgs_info = {key: inputs[key].imgs_info for key in inputs} - return { - "images": images, - "labels": labels, - "imgs_info": imgs_info, - "mode": mode, - } - return { - "images": inputs.stacked_images, - "labels": torch.cat(inputs.labels, dim=0), - "imgs_info": inputs.imgs_info, - "mode": mode, - } - - def training_step(self, batch: MulticlassClsBatchDataEntity, batch_idx: int) -> Tensor: - """Performs a single training step on a batch of data. - - Args: - batch (MulticlassClsBatchDataEntity): The input batch of data. - batch_idx (int): The index of the current batch. - - Returns: - Tensor: The computed loss for the training step. - """ - loss = super().training_step(batch, batch_idx) - # Collect metrics related to Semi-SL Training. 
- self.log( - "train/unlabeled_coef", - self.model.head.unlabeled_coef, - on_step=True, - on_epoch=False, - prog_bar=True, - ) - self.log( - "train/num_pseudo_label", - self.model.head.num_pseudo_label, - on_step=True, - on_epoch=False, - prog_bar=True, - ) - return loss - class VisionTransformerForMultilabelCls(ForwardExplainMixInForViT, OTXMultilabelClsModel): """DeitTiny Model for multi-class classification task.""" @@ -533,69 +413,6 @@ def _build_model(self, num_classes: int) -> nn.Module: init_cfg=init_cfg, ) - def _customize_inputs(self, inputs: MultilabelClsBatchDataEntity) -> dict[str, Any]: - if self.training: - mode = "loss" - elif self.explain_mode: - mode = "explain" - else: - mode = "predict" - - return { - "images": inputs.stacked_images, - "labels": torch.stack(inputs.labels), - "imgs_info": inputs.imgs_info, - "mode": mode, - } - - def _customize_outputs( - self, - outputs: Any, # noqa: ANN401 - inputs: MultilabelClsBatchDataEntity, - ) -> MultilabelClsBatchPredEntity | OTXBatchLossEntity: - if self.training: - return OTXBatchLossEntity(loss=outputs) - - if self.explain_mode: - return MultilabelClsBatchPredEntity( - batch_size=inputs.batch_size, - images=inputs.images, - imgs_info=inputs.imgs_info, - scores=outputs["scores"], - labels=outputs["labels"], - saliency_map=outputs["saliency_map"], - feature_vector=outputs["feature_vector"], - ) - - # To list, batch-wise - logits = outputs if isinstance(outputs, torch.Tensor) else outputs["logits"] - scores = torch.unbind(logits, 0) - preds = logits.argmax(-1, keepdim=True).unbind(0) - - return MultilabelClsBatchPredEntity( - batch_size=inputs.batch_size, - images=inputs.images, - imgs_info=inputs.imgs_info, - scores=scores, - labels=preds, - ) - - @property - def _exporter(self) -> OTXModelExporter: - """Creates OTXModelExporter object that can export the model.""" - return OTXNativeModelExporter( - task_level_export_parameters=self._export_parameters, - input_size=(1, 3, 224, 224), - mean=(123.675, 116.28, 103.53), - std=(58.395, 57.12, 57.375), - resize_mode="standard", - pad_value=0, - swap_rgb=False, - via_onnx=False, - onnx_export_configuration=None, - output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, - ) - class VisionTransformerForHLabelCls(ForwardExplainMixInForViT, OTXHlabelClsModel): """DeitTiny Model for hierarchical label classification task.""" @@ -685,88 +502,3 @@ def _build_model(self, head_config: dict) -> nn.Module: ), init_cfg=init_cfg, ) - - def _customize_inputs(self, inputs: HlabelClsBatchDataEntity) -> dict[str, Any]: - if self.training: - mode = "loss" - elif self.explain_mode: - mode = "explain" - else: - mode = "predict" - - return { - "images": inputs.stacked_images, - "labels": torch.stack(inputs.labels), - "imgs_info": inputs.imgs_info, - "mode": mode, - } - - def _customize_outputs( - self, - outputs: Any, # noqa: ANN401 - inputs: HlabelClsBatchDataEntity, - ) -> HlabelClsBatchPredEntity | OTXBatchLossEntity: - if self.training: - return OTXBatchLossEntity(loss=outputs) - - if isinstance(outputs, dict): - scores = outputs["scores"] - labels = outputs["labels"] - else: - scores = outputs - labels = outputs.argmax(-1, keepdim=True) - - if self.explain_mode: - return HlabelClsBatchPredEntity( - batch_size=inputs.batch_size, - images=inputs.images, - imgs_info=inputs.imgs_info, - scores=scores, - labels=labels, - saliency_map=outputs["saliency_map"], - feature_vector=outputs["feature_vector"], - ) - - return HlabelClsBatchPredEntity( - batch_size=inputs.batch_size, 
- images=inputs.images, - imgs_info=inputs.imgs_info, - scores=scores, - labels=labels, - ) - - def _convert_pred_entity_to_compute_metric( - self, - preds: HlabelClsBatchPredEntity, - inputs: HlabelClsBatchDataEntity, - ) -> MetricInput: - hlabel_info: HLabelInfo = self.label_info # type: ignore[assignment] - - _labels = torch.stack(preds.labels) if isinstance(preds.labels, list) else preds.labels - _scores = torch.stack(preds.scores) if isinstance(preds.scores, list) else preds.scores - if hlabel_info.num_multilabel_classes > 0: - preds_multiclass = _labels[:, : hlabel_info.num_multiclass_heads] - preds_multilabel = _scores[:, hlabel_info.num_multiclass_heads :] - pred_result = torch.cat([preds_multiclass, preds_multilabel], dim=1) - else: - pred_result = _labels - return { - "preds": pred_result, - "target": torch.stack(inputs.labels), - } - - @property - def _exporter(self) -> OTXModelExporter: - """Creates OTXModelExporter object that can export the model.""" - return OTXNativeModelExporter( - task_level_export_parameters=self._export_parameters, - input_size=(1, 3, 224, 224), - mean=(123.675, 116.28, 103.53), - std=(58.395, 57.12, 57.375), - resize_mode="standard", - pad_value=0, - swap_rgb=False, - via_onnx=False, - onnx_export_configuration=None, - output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, - ) diff --git a/src/otx/core/model/anomaly.py b/src/otx/core/model/anomaly.py index deb6607b823..1e3e9e0dd1f 100644 --- a/src/otx/core/model/anomaly.py +++ b/src/otx/core/model/anomaly.py @@ -25,6 +25,7 @@ ) from otx.core.data.entity.base import ImageInfo from otx.core.exporter.anomaly import OTXAnomalyModelExporter +from otx.core.model.base import OTXModel from otx.core.types.export import OTXExportFormatType from otx.core.types.precision import OTXPrecisionType from otx.core.types.task import OTXTaskType @@ -37,9 +38,8 @@ from lightning.pytorch import Trainer from lightning.pytorch.callbacks.callback import Callback from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable - from lightning.pytorch.utilities.types import STEP_OUTPUT from torchmetrics import Metric - +from otx.core.types.label import AnomalyLabelInfo AnomalyModelInputs: TypeAlias = ( AnomalyClassificationDataBatch | AnomalySegmentationDataBatch | AnomalyDetectionDataBatch @@ -49,10 +49,11 @@ ) -class OTXAnomaly: +class OTXAnomaly(OTXModel): """Methods used to make OTX model compatible with the Anomalib model.""" def __init__(self) -> None: + super().__init__(label_info=AnomalyLabelInfo()) self.optimizer: list[OptimizerCallable] | OptimizerCallable = None self.scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = None self._input_size: tuple[int, int] = (256, 256) @@ -154,37 +155,6 @@ def configure_callbacks(self) -> list[Callback]: ), ] - def on_test_batch_end( - self, - outputs: dict, - batch: AnomalyModelInputs | dict, - batch_idx: int, - dataloader_idx: int = 0, - ) -> None: - """Called in the predict loop after the batch. - - Args: - outputs: The outputs of predict_step(x) - batch: The batched data as it is returned by the prediction DataLoader. 
- batch_idx: the index of the batch - dataloader_idx: the index of the dataloader - - """ - if not isinstance(batch, dict): - batch = self._customize_inputs(batch) - super().on_test_batch_end(outputs, batch, batch_idx, dataloader_idx) # type: ignore[misc] - - def predict_step( - self, - inputs: AnomalyModelInputs | dict, - batch_idx: int = 0, - **kwargs, - ) -> dict: - """Return predictions from the anomalib model.""" - if not isinstance(inputs, dict): - inputs = self._customize_inputs(inputs) - return super().predict_step(inputs, batch_idx, **kwargs) # type: ignore[misc] - def on_predict_batch_end( self, outputs: dict, @@ -203,46 +173,6 @@ def on_predict_batch_end( outputs.clear() outputs.update({"prediction": _outputs}) - def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[torch.optim.Optimizer]] | None: # type: ignore[override] - """Configure optimizers for Anomalib models. - - If the anomalib lightning model supports optimizers, return the optimizer. - If ``self.trainable_model`` is None then the model does not support training. - Else don't return optimizer even if it is configured in the OTX model. - """ - # [TODO](ashwinvaidya17): Revisit this method - if self.optimizer and self.trainable_model: - optimizer = self.optimizer - if isinstance(optimizer, list): - if len(optimizer) > 1: - msg = "Only one optimizer should be passed" - raise ValueError(msg) - optimizer = optimizer[0] - params = getattr(self.model, self.trainable_model).parameters() - return optimizer(params=params) - return super().configure_optimizers() # type: ignore[misc] - - def validation_step( - self, - inputs: AnomalyModelInputs, - batch_idx: int = 0, - ) -> STEP_OUTPUT: - """Call validation step of the anomalib model.""" - raise NotImplementedError - - def forward( - self, - inputs: AnomalyModelInputs, - ) -> AnomalyModelOutputs: - """Wrap forward method of the Anomalib model.""" - outputs = self.validation_step(inputs) - # TODO(Ashwin): update forward implementation to comply with other OTX models - _PostProcessorCallback._post_process(outputs) # noqa: SLF001 - _PostProcessorCallback._compute_scores_and_labels(self, outputs) # noqa: SLF001 - _MinMaxNormalizationCallback._normalize_batch(outputs, self) # noqa: SLF001 - - return self._customize_outputs(outputs=outputs, inputs=inputs) - def _customize_inputs( self, inputs: AnomalyModelInputs, diff --git a/src/otx/core/model/classification.py b/src/otx/core/model/classification.py index 2b66d17c7e4..7f9f6e63c62 100644 --- a/src/otx/core/model/classification.py +++ b/src/otx/core/model/classification.py @@ -4,12 +4,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal import numpy as np import torch from torch import Tensor +from otx.core.data.entity.base import OTXBatchLossEntity from otx.core.data.entity.classification import ( HlabelClsBatchDataEntity, HlabelClsBatchPredEntity, @@ -18,6 +19,8 @@ MultilabelClsBatchDataEntity, MultilabelClsBatchPredEntity, ) +from otx.core.exporter.base import OTXModelExporter +from otx.core.exporter.native import OTXNativeModelExporter from otx.core.metrics import MetricInput from otx.core.metrics.accuracy import ( HLabelClsMetricCallble, @@ -28,6 +31,7 @@ from otx.core.schedulers import LRSchedulerListCallable from otx.core.types.export import TaskLevelExportParameters from otx.core.types.label import HLabelInfo, LabelInfo, LabelInfoTypes +from otx.core.types.task import OTXTrainType if TYPE_CHECKING: from lightning.pytorch.cli import 
LRSchedulerCallable, OptimizerCallable @@ -46,7 +50,10 @@ def __init__( scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = MultiClassClsMetricCallable, torch_compile: bool = False, + train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED, ) -> None: + self.train_type = train_type + super().__init__( label_info=label_info, optimizer=optimizer, @@ -56,6 +63,88 @@ def __init__( ) self.image_size = (1, 3, 224, 224) + def _customize_inputs(self, inputs: MulticlassClsBatchDataEntity) -> dict[str, Any]: + if self.training: + mode = "loss" + elif self.explain_mode: + mode = "explain" + else: + mode = "predict" + + if self.train_type == OTXTrainType.SEMI_SUPERVISED and isinstance(inputs, dict): + # When used with an unlabeled dataset, it comes in as a dict. + images = {key: inputs[key].images for key in inputs} + labels = {key: torch.cat(inputs[key].labels, dim=0) for key in inputs} + imgs_info = {key: inputs[key].imgs_info for key in inputs} + return { + "images": images, + "labels": labels, + "imgs_info": imgs_info, + "mode": mode, + } + + return { + "images": inputs.stacked_images, + "labels": torch.cat(inputs.labels, dim=0), + "imgs_info": inputs.imgs_info, + "mode": mode, + } + + def _customize_outputs( + self, + outputs: Any, # noqa: ANN401 + inputs: MulticlassClsBatchDataEntity, + ) -> MulticlassClsBatchPredEntity | OTXBatchLossEntity: + if self.training: + return OTXBatchLossEntity(loss=outputs) + + if self.explain_mode: + return MulticlassClsBatchPredEntity( + batch_size=inputs.batch_size, + images=inputs.images, + imgs_info=inputs.imgs_info, + scores=outputs["scores"], + labels=outputs["labels"], + saliency_map=outputs["saliency_map"], + feature_vector=outputs["feature_vector"], + ) + + # To list, batch-wise + logits = outputs if isinstance(outputs, torch.Tensor) else outputs["logits"] + scores = torch.unbind(logits, 0) + preds = logits.argmax(-1, keepdim=True).unbind(0) + + return MulticlassClsBatchPredEntity( + batch_size=inputs.batch_size, + images=inputs.stacked_images, + imgs_info=inputs.imgs_info, + scores=scores, + labels=preds, + ) + + def training_step(self, batch: MulticlassClsBatchDataEntity, batch_idx: int) -> Tensor: + """Performs a single training step on a batch of data.""" + loss = super().training_step(batch, batch_idx) + # Collect metrics related to Semi-SL Training. 
+ if self.train_type == OTXTrainType.SEMI_SUPERVISED: + if hasattr(self.model.head, "unlabeled_coef"): + self.log( + "train/unlabeled_coef", + self.model.head.unlabeled_coef, + on_step=True, + on_epoch=False, + prog_bar=True, + ) + if hasattr(self.model.head, "num_pseudo_label"): + self.log( + "train/num_pseudo_label", + self.model.head.num_pseudo_label, + on_step=True, + on_epoch=False, + prog_bar=True, + ) + return loss + @property def _export_parameters(self) -> TaskLevelExportParameters: """Defines parameters required to export a particular model implementation.""" @@ -66,6 +155,22 @@ def _export_parameters(self) -> TaskLevelExportParameters: hierarchical=False, ) + @property + def _exporter(self) -> OTXModelExporter: + """Creates OTXModelExporter object that can export the model.""" + return OTXNativeModelExporter( + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + resize_mode="standard", + pad_value=0, + swap_rgb=False, + via_onnx=False, + onnx_export_configuration=None, + output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, + ) + def _convert_pred_entity_to_compute_metric( self, preds: MulticlassClsBatchPredEntity, @@ -78,12 +183,19 @@ def _convert_pred_entity_to_compute_metric( "target": target, } + def _reset_prediction_layer(self, num_classes: int) -> None: + return + def get_dummy_input(self, batch_size: int = 1) -> MulticlassClsBatchDataEntity: """Returns a dummy input for classification model.""" images = [torch.rand(*self.image_size[1:]) for _ in range(batch_size)] labels = [torch.LongTensor([0])] * batch_size return MulticlassClsBatchDataEntity(batch_size, images, [], labels=labels) + def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]: + """Model forward function used for the model tracing during model exportation.""" + return self.model(images=image) + ### NOTE, currently, although we've made the separate Multi-cls, Multi-label classes ### It'll be integrated after H-label classification integration with more advanced design. 
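A minimal usage sketch (illustrative only, not part of the patch): with `train_type` now accepted by the base classification models, the dedicated `*SemiSL` subclasses removed elsewhere in this diff are no longer required. The names below (`TimmModelForMulticlassCls`, `OTXTrainType`, the `efficientnetv2_s_21k` backbone) are taken from the code and recipes introduced in this diff; the numeric `label_info` mirrors the unit tests.

    from otx.algo.classification.timm_model import TimmModelForMulticlassCls
    from otx.core.types.task import OTXTrainType

    # Selecting SEMI_SUPERVISED via train_type makes _build_model return a
    # SemiSLClassifier with an OTXSemiSLLinearClsHead instead of the plain
    # ImageClassifier, and enables the Semi-SL metric logging in training_step.
    model = TimmModelForMulticlassCls(
        label_info=10,
        backbone="efficientnetv2_s_21k",
        train_type=OTXTrainType.SEMI_SUPERVISED,
    )

The recipe YAMLs below express the same choice declaratively via `train_type: SEMI_SUPERVISED` under `init_args`.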
@@ -109,6 +221,52 @@ def __init__( ) self.image_size = (1, 3, 224, 224) + def _customize_inputs(self, inputs: MultilabelClsBatchDataEntity) -> dict[str, Any]: + if self.training: + mode = "loss" + elif self.explain_mode: + mode = "explain" + else: + mode = "predict" + + return { + "images": inputs.stacked_images, + "labels": torch.stack(inputs.labels), + "imgs_info": inputs.imgs_info, + "mode": mode, + } + + def _customize_outputs( + self, + outputs: Any, # noqa: ANN401 + inputs: MultilabelClsBatchDataEntity, + ) -> MultilabelClsBatchPredEntity | OTXBatchLossEntity: + if self.training: + return OTXBatchLossEntity(loss=outputs) + + if self.explain_mode: + return MultilabelClsBatchPredEntity( + batch_size=inputs.batch_size, + images=inputs.images, + imgs_info=inputs.imgs_info, + scores=outputs["scores"], + labels=outputs["labels"], + saliency_map=outputs["saliency_map"], + feature_vector=outputs["feature_vector"], + ) + + # To list, batch-wise + logits = outputs if isinstance(outputs, torch.Tensor) else outputs["logits"] + scores = torch.unbind(logits, 0) + + return MultilabelClsBatchPredEntity( + batch_size=inputs.batch_size, + images=inputs.images, + imgs_info=inputs.imgs_info, + scores=scores, + labels=logits.argmax(-1, keepdim=True).unbind(0), + ) + @property def _export_parameters(self) -> TaskLevelExportParameters: """Defines parameters required to export a particular model implementation.""" @@ -120,6 +278,22 @@ def _export_parameters(self) -> TaskLevelExportParameters: confidence_threshold=0.5, ) + @property + def _exporter(self) -> OTXModelExporter: + """Creates OTXModelExporter object that can export the model.""" + return OTXNativeModelExporter( + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + resize_mode="standard", + pad_value=0, + swap_rgb=False, + via_onnx=False, + onnx_export_configuration=None, + output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, + ) + def _convert_pred_entity_to_compute_metric( self, preds: MultilabelClsBatchPredEntity, @@ -132,7 +306,7 @@ def _convert_pred_entity_to_compute_metric( def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]: """Model forward function used for the model tracing during model exportation.""" - return self.model.forward(image, mode="tensor") + return self.model.forward(image) def get_dummy_input(self, batch_size: int = 1) -> MultilabelClsBatchDataEntity: """Returns a dummy input for classification OV model.""" @@ -144,6 +318,8 @@ def get_dummy_input(self, batch_size: int = 1) -> MultilabelClsBatchDataEntity: class OTXHlabelClsModel(OTXModel[HlabelClsBatchDataEntity, HlabelClsBatchPredEntity]): """H-label classification models used in OTX.""" + label_info: HLabelInfo + def __init__( self, label_info: HLabelInfo, @@ -161,6 +337,56 @@ def __init__( ) self.image_size = (1, 3, 224, 224) + def _customize_inputs(self, inputs: HlabelClsBatchDataEntity) -> dict[str, Any]: + if self.training: + mode = "loss" + elif self.explain_mode: + mode = "explain" + else: + mode = "predict" + + return { + "images": inputs.stacked_images, + "labels": torch.stack(inputs.labels), + "imgs_info": inputs.imgs_info, + "mode": mode, + } + + def _customize_outputs( + self, + outputs: Any, # noqa: ANN401 + inputs: HlabelClsBatchDataEntity, + ) -> HlabelClsBatchPredEntity | OTXBatchLossEntity: + if self.training: + return OTXBatchLossEntity(loss=outputs) + + # To list, batch-wise + if isinstance(outputs, 
dict): + scores = outputs["scores"] + labels = outputs["labels"] + else: + scores = outputs + labels = outputs.argmax(-1, keepdim=True) + + if self.explain_mode: + return HlabelClsBatchPredEntity( + batch_size=inputs.batch_size, + images=inputs.images, + imgs_info=inputs.imgs_info, + scores=scores, + labels=labels, + saliency_map=outputs["saliency_map"], + feature_vector=outputs["feature_vector"], + ) + + return HlabelClsBatchPredEntity( + batch_size=inputs.batch_size, + images=inputs.images, + imgs_info=inputs.imgs_info, + scores=scores, + labels=labels, + ) + @property def _export_parameters(self) -> TaskLevelExportParameters: """Defines parameters required to export a particular model implementation.""" @@ -172,6 +398,22 @@ def _export_parameters(self) -> TaskLevelExportParameters: confidence_threshold=0.5, ) + @property + def _exporter(self) -> OTXModelExporter: + """Creates OTXModelExporter object that can export the model.""" + return OTXNativeModelExporter( + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + resize_mode="standard", + pad_value=0, + swap_rgb=False, + via_onnx=False, + onnx_export_configuration=None, + output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, + ) + def _convert_pred_entity_to_compute_metric( self, preds: HlabelClsBatchPredEntity, @@ -179,12 +421,14 @@ def _convert_pred_entity_to_compute_metric( ) -> MetricInput: hlabel_info: HLabelInfo = self.label_info # type: ignore[assignment] + _labels = torch.stack(preds.labels) if isinstance(preds.labels, list) else preds.labels + _scores = torch.stack(preds.scores) if isinstance(preds.scores, list) else preds.scores if hlabel_info.num_multilabel_classes > 0: - preds_multiclass = torch.stack(preds.labels)[:, : hlabel_info.num_multiclass_heads] - preds_multilabel = torch.stack(preds.scores)[:, hlabel_info.num_multiclass_heads :] + preds_multiclass = _labels[:, : hlabel_info.num_multiclass_heads] + preds_multilabel = _scores[:, hlabel_info.num_multiclass_heads :] pred_result = torch.cat([preds_multiclass, preds_multilabel], dim=1) else: - pred_result = torch.stack(preds.labels) + pred_result = _labels return { "preds": pred_result, "target": torch.stack(inputs.labels), @@ -203,6 +447,10 @@ def get_dummy_input(self, batch_size: int = 1) -> HlabelClsBatchDataEntity: labels = [torch.LongTensor([0])] * batch_size return HlabelClsBatchDataEntity(batch_size, images, [], labels=labels) + def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]: + """Model forward function used for the model tracing during model exportation.""" + return self.model(images=image) + class OVMulticlassClassificationModel( OVModel[MulticlassClsBatchDataEntity, MulticlassClsBatchPredEntity], diff --git a/src/otx/recipe/classification/h_label_cls/efficientnet_v2.yaml b/src/otx/recipe/classification/h_label_cls/efficientnet_v2.yaml index ade39a63c2c..848a985d433 100644 --- a/src/otx/recipe/classification/h_label_cls/efficientnet_v2.yaml +++ b/src/otx/recipe/classification/h_label_cls/efficientnet_v2.yaml @@ -1,6 +1,8 @@ model: - class_path: otx.algo.classification.efficientnet_v2.EfficientNetV2ForHLabelCls + class_path: otx.algo.classification.timm_model.TimmModelForHLabelCls init_args: + backbone: efficientnetv2_s_21k + optimizer: class_path: torch.optim.SGD init_args: diff --git a/src/otx/recipe/classification/multi_class_cls/efficientnet_v2.yaml 
b/src/otx/recipe/classification/multi_class_cls/efficientnet_v2.yaml index 71bf5a929df..2ca3c354f73 100644 --- a/src/otx/recipe/classification/multi_class_cls/efficientnet_v2.yaml +++ b/src/otx/recipe/classification/multi_class_cls/efficientnet_v2.yaml @@ -1,7 +1,8 @@ model: - class_path: otx.algo.classification.efficientnet_v2.EfficientNetV2ForMulticlassCls + class_path: otx.algo.classification.timm_model.TimmModelForMulticlassCls init_args: label_info: 1000 + backbone: efficientnetv2_s_21k optimizer: class_path: torch.optim.SGD diff --git a/src/otx/recipe/classification/multi_class_cls/semisl/deit_tiny_semisl.yaml b/src/otx/recipe/classification/multi_class_cls/semisl/deit_tiny_semisl.yaml index 59a65005092..34155bf6089 100644 --- a/src/otx/recipe/classification/multi_class_cls/semisl/deit_tiny_semisl.yaml +++ b/src/otx/recipe/classification/multi_class_cls/semisl/deit_tiny_semisl.yaml @@ -1,8 +1,9 @@ model: - class_path: otx.algo.classification.vit.VisionTransformerForMulticlassClsSemiSL + class_path: otx.algo.classification.vit.VisionTransformerForMulticlassCls init_args: label_info: 1000 arch: "vit-tiny" + train_type: SEMI_SUPERVISED optimizer: class_path: torch.optim.AdamW diff --git a/src/otx/recipe/classification/multi_class_cls/semisl/dino_v2_semisl.yaml b/src/otx/recipe/classification/multi_class_cls/semisl/dino_v2_semisl.yaml index beebb234306..9b06bfb51e1 100644 --- a/src/otx/recipe/classification/multi_class_cls/semisl/dino_v2_semisl.yaml +++ b/src/otx/recipe/classification/multi_class_cls/semisl/dino_v2_semisl.yaml @@ -1,22 +1,23 @@ model: - class_path: otx.algo.classification.dino_v2.DINOv2ForMulticlassClsSemiSL + class_path: otx.algo.classification.vit.VisionTransformerForMulticlassCls init_args: label_info: 1000 - freeze_backbone: false - backbone: dinov2_vits14_reg + arch: "dinov2-small" + train_type: SEMI_SUPERVISED optimizer: class_path: torch.optim.AdamW init_args: - lr: 1e-5 + lr: 0.0001 + weight_decay: 0.05 scheduler: class_path: lightning.pytorch.cli.ReduceLROnPlateau init_args: - mode: min + mode: max factor: 0.5 - patience: 9 - monitor: train/loss + patience: 1 + monitor: val/accuracy engine: task: MULTI_CLASS_CLS diff --git a/src/otx/recipe/classification/multi_class_cls/semisl/efficientnet_b0_semisl.yaml b/src/otx/recipe/classification/multi_class_cls/semisl/efficientnet_b0_semisl.yaml index 6128ae0dd2a..30126060f6e 100644 --- a/src/otx/recipe/classification/multi_class_cls/semisl/efficientnet_b0_semisl.yaml +++ b/src/otx/recipe/classification/multi_class_cls/semisl/efficientnet_b0_semisl.yaml @@ -1,8 +1,9 @@ model: - class_path: otx.algo.classification.efficientnet.EfficientNetForMulticlassClsSemiSL + class_path: otx.algo.classification.efficientnet.EfficientNetForMulticlassCls init_args: label_info: 1000 version: b0 + train_type: SEMI_SUPERVISED optimizer: class_path: torch.optim.SGD diff --git a/src/otx/recipe/classification/multi_class_cls/semisl/efficientnet_v2_semisl.yaml b/src/otx/recipe/classification/multi_class_cls/semisl/efficientnet_v2_semisl.yaml index 6bcbda09a19..b1f87665bde 100644 --- a/src/otx/recipe/classification/multi_class_cls/semisl/efficientnet_v2_semisl.yaml +++ b/src/otx/recipe/classification/multi_class_cls/semisl/efficientnet_v2_semisl.yaml @@ -1,7 +1,9 @@ model: - class_path: otx.algo.classification.efficientnet_v2.EfficientNetV2ForMulticlassClsSemiSL + class_path: otx.algo.classification.timm_model.TimmModelForMulticlassCls init_args: label_info: 1000 + backbone: efficientnetv2_s_21k + train_type: SEMI_SUPERVISED optimizer: 
class_path: torch.optim.SGD diff --git a/src/otx/recipe/classification/multi_class_cls/semisl/mobilenet_v3_large_semisl.yaml b/src/otx/recipe/classification/multi_class_cls/semisl/mobilenet_v3_large_semisl.yaml index 55720c17897..e0dab91d687 100644 --- a/src/otx/recipe/classification/multi_class_cls/semisl/mobilenet_v3_large_semisl.yaml +++ b/src/otx/recipe/classification/multi_class_cls/semisl/mobilenet_v3_large_semisl.yaml @@ -1,8 +1,9 @@ model: - class_path: otx.algo.classification.mobilenet_v3.MobileNetV3ForMulticlassClsSemiSL + class_path: otx.algo.classification.mobilenet_v3.MobileNetV3ForMulticlassCls init_args: mode: large label_info: 1000 + train_type: SEMI_SUPERVISED optimizer: class_path: torch.optim.SGD diff --git a/src/otx/recipe/classification/multi_label_cls/efficientnet_v2.yaml b/src/otx/recipe/classification/multi_label_cls/efficientnet_v2.yaml index df4e1d76730..cc6ec415ec2 100644 --- a/src/otx/recipe/classification/multi_label_cls/efficientnet_v2.yaml +++ b/src/otx/recipe/classification/multi_label_cls/efficientnet_v2.yaml @@ -1,7 +1,8 @@ model: - class_path: otx.algo.classification.efficientnet_v2.EfficientNetV2ForMultilabelCls + class_path: otx.algo.classification.timm_model.TimmModelForMultilabelCls init_args: label_info: 1000 + backbone: efficientnetv2_s_21k optimizer: class_path: torch.optim.SGD diff --git a/tests/unit/algo/callbacks/test_unlabeled_loss_warmup.py b/tests/unit/algo/callbacks/test_unlabeled_loss_warmup.py index 54ea262ac71..d490738098a 100644 --- a/tests/unit/algo/callbacks/test_unlabeled_loss_warmup.py +++ b/tests/unit/algo/callbacks/test_unlabeled_loss_warmup.py @@ -3,12 +3,12 @@ import pytest from lightning import Trainer from otx.algo.callbacks.unlabeled_loss_warmup import UnlabeledLossWarmUpCallback -from otx.algo.classification.efficientnet import EfficientNetForMulticlassClsSemiSL +from otx.algo.classification.efficientnet import EfficientNetForMulticlassCls @pytest.fixture() def fxt_semisl_model(): - return EfficientNetForMulticlassClsSemiSL(10) + return EfficientNetForMulticlassCls(10, train_type="SEMI_SUPERVISED") def test_unlabeled_loss_warmup_callback(mocker, fxt_semisl_model): diff --git a/tests/unit/algo/classification/test_otx_dino_v2.py b/tests/unit/algo/classification/test_otx_dino_v2.py deleted file mode 100644 index 2ffa177976c..00000000000 --- a/tests/unit/algo/classification/test_otx_dino_v2.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright (C) 2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from unittest.mock import MagicMock, patch - -import pytest -import torch -from otx.algo.classification.dino_v2 import DINOv2, DINOv2RegisterClassifier -from otx.core.data.entity.base import OTXBatchLossEntity -from otx.core.data.entity.classification import MulticlassClsBatchPredEntity - - -class TestDINOv2: - @pytest.fixture() - def model_freeze_backbone(self) -> None: - mock_backbone = MagicMock() - mock_backbone.return_value = torch.randn(1, 12) - - with patch("torch.hub.load", autospec=True) as mock_load: - mock_load.return_value = mock_backbone - - return DINOv2( - backbone="dinov2_vits14_reg", - freeze_backbone=True, - head_in_channels=12, - num_classes=2, - ) - - def test_freeze_backbone(self, model_freeze_backbone) -> None: - for _, v in model_freeze_backbone.backbone.named_parameters(): - assert v.requires_grad is False - - def test_forward(self, model_freeze_backbone) -> None: - rand_img = torch.randn((1, 3, 224, 224), dtype=torch.float32) - rand_label = torch.ones((1), 
dtype=torch.int64) - outputs = model_freeze_backbone(rand_img, rand_label) - assert isinstance(outputs, torch.Tensor) - - -class TestDINOv2RegisterClassifier: - @pytest.fixture() - def otx_model(self) -> DINOv2RegisterClassifier: - return DINOv2RegisterClassifier(label_info=1) - - def test_create_model(self, otx_model): - assert isinstance(otx_model.model, DINOv2) - - def test_customize_inputs(self, otx_model, fxt_multiclass_cls_batch_data_entity): - outputs = otx_model._customize_inputs(fxt_multiclass_cls_batch_data_entity) - assert "imgs" in outputs - assert "labels" in outputs - assert "imgs_info" in outputs - - def test_customize_outputs(self, otx_model, fxt_multiclass_cls_batch_data_entity): - outputs = torch.randn(2, 10) - otx_model.training = True - preds = otx_model._customize_outputs(outputs, fxt_multiclass_cls_batch_data_entity) - assert isinstance(preds, OTXBatchLossEntity) - - otx_model.training = False - preds = otx_model._customize_outputs(outputs, fxt_multiclass_cls_batch_data_entity) - assert isinstance(preds, MulticlassClsBatchPredEntity) - - def test_predict_step(self, otx_model, fxt_multiclass_cls_batch_data_entity): - otx_model.eval() - outputs = otx_model.predict_step(batch=fxt_multiclass_cls_batch_data_entity, batch_idx=0) - - assert isinstance(outputs, MulticlassClsBatchPredEntity) diff --git a/tests/unit/algo/classification/test_efficientnet_v2.py b/tests/unit/algo/classification/test_timm_model.py similarity index 91% rename from tests/unit/algo/classification/test_efficientnet_v2.py rename to tests/unit/algo/classification/test_timm_model.py index 52c3182809c..fbb4d6fbbc0 100644 --- a/tests/unit/algo/classification/test_efficientnet_v2.py +++ b/tests/unit/algo/classification/test_timm_model.py @@ -4,10 +4,10 @@ import pytest import torch from otx.algo.classification.classifier import ImageClassifier -from otx.algo.classification.efficientnet_v2 import ( - EfficientNetV2ForHLabelCls, - EfficientNetV2ForMulticlassCls, - EfficientNetV2ForMultilabelCls, +from otx.algo.classification.timm_model import ( + TimmModelForHLabelCls, + TimmModelForMulticlassCls, + TimmModelForMultilabelCls, ) from otx.core.data.entity.base import OTXBatchLossEntity from otx.core.data.entity.classification import ( @@ -19,12 +19,13 @@ @pytest.fixture() def fxt_multi_class_cls_model(): - return EfficientNetV2ForMulticlassCls( + return TimmModelForMulticlassCls( label_info=10, + backbone="efficientnetv2_s_21k", ) -class TestEfficientNetV2ForMulticlassCls: +class TestTimmModelForMulticlassCls: def test_create_model(self, fxt_multi_class_cls_model): assert isinstance(fxt_multi_class_cls_model.model, ImageClassifier) @@ -56,12 +57,13 @@ def test_predict_step(self, fxt_multi_class_cls_model, fxt_multiclass_cls_batch_ @pytest.fixture() def fxt_multi_label_cls_model(): - return EfficientNetV2ForMultilabelCls( + return TimmModelForMultilabelCls( label_info=10, + backbone="efficientnetv2_s_21k", ) -class TestEfficientNetV2ForMultilabelCls: +class TestTimmModelForMultilabelCls: def test_create_model(self, fxt_multi_label_cls_model): assert isinstance(fxt_multi_label_cls_model.model, ImageClassifier) @@ -93,12 +95,13 @@ def test_predict_step(self, fxt_multi_label_cls_model, fxt_multilabel_cls_batch_ @pytest.fixture() def fxt_h_label_cls_model(fxt_hlabel_data): - return EfficientNetV2ForHLabelCls( + return TimmModelForHLabelCls( label_info=fxt_hlabel_data, + backbone="efficientnetv2_s_21k", ) -class TestEfficientNetV2ForHLabelCls: +class TestTimmModelForHLabelCls: def test_create_model(self, 
fxt_h_label_cls_model): assert isinstance(fxt_h_label_cls_model.model, ImageClassifier)