Skip to content

Commit

Permalink
Add num_classes to OTXModel and revisit DataMetaInfo
Browse files Browse the repository at this point in the history
Signed-off-by: Kim, Vinnam <vinnam.kim@intel.com>
  • Loading branch information
vinnamkim committed Jan 11, 2024
1 parent 9d5818a commit bc6e7fc
Show file tree
Hide file tree
Showing 26 changed files with 183 additions and 82 deletions.
4 changes: 2 additions & 2 deletions src/otx/algo/classification/otx_dino_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def forward(self, imgs: torch.Tensor, labels: torch.Tensor = None) -> torch.Tens
class DINOv2RegisterClassifier(OTXMulticlassClsModel):
"""DINO-v2 Classification Model with register."""

def __init__(self, config: DictConfig) -> None:
def __init__(self, num_classes: int, config: DictConfig) -> None:
self.config = config
super().__init__() # create the model
super().__init__(num_classes=num_classes) # create the model

def _create_model(self) -> nn.Module:
"""Create the model."""
Expand Down
3 changes: 3 additions & 0 deletions src/otx/config/model/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ scheduler:
mode: min
factor: 0.1
patience: 10

otx_model:
num_classes: ???
3 changes: 3 additions & 0 deletions src/otx/config/model/mmdet_inst_seg.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
defaults:
- default

_target_: otx.core.model.module.instance_segmentation.OTXInstanceSegLitModule

optimizer:
Expand Down
1 change: 1 addition & 0 deletions src/otx/config/model/mmseg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ scheduler:
otx_model:
_target_: otx.core.model.entity.segmentation.MMSegCompatibleModel
config: ???
num_classes: ???

# compile model for faster training with pytorch 2.0
torch_compile: false
26 changes: 19 additions & 7 deletions src/otx/core/data/dataset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,27 @@


@dataclass
class DataMetaInfo:
"""Meta information of each subset datasets."""
class LabelInfo:
"""Object to represent label information."""

class_names: list[str]
label_names: list[str]

@property
def num_classes(self) -> int:
"""Return number of classes."""
return len(self.class_names)
"""Return number of labels."""
return len(self.label_names)

@classmethod
def from_num_classes(cls, num_classes: int) -> LabelInfo:
"""Create this object from the number of classes.
Args:
num_classes: Number of classes
Returns:
LabelInfo(label_names=["label_0", ...])
"""
return LabelInfo(label_names=[f"label_{idx}" for idx in range(num_classes)])


class OTXDataset(Dataset, Generic[T_OTXDataEntity]):
Expand All @@ -59,8 +71,8 @@ def __init__(
self.mem_cache_img_max_size = mem_cache_img_max_size
self.max_refetch = max_refetch

self.meta_info = DataMetaInfo(
class_names=[category.name for category in self.dm_subset.categories()[AnnotationType.label]],
self.meta_info = LabelInfo(
label_names=[category.name for category in self.dm_subset.categories()[AnnotationType.label]],
)

def __len__(self) -> int:
Expand Down
6 changes: 3 additions & 3 deletions src/otx/core/data/dataset/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from datumaro.components.annotation import AnnotationType
from torch.nn import functional

from otx.core.data.dataset.base import DataMetaInfo, OTXDataset
from otx.core.data.dataset.base import LabelInfo, OTXDataset
from otx.core.data.entity.base import ImageInfo
from otx.core.data.entity.classification import (
HlabelClsBatchDataEntity,
Expand All @@ -27,7 +27,7 @@


@dataclass
class HLabelMetaInfo(DataMetaInfo):
class HLabelMetaInfo(LabelInfo):
"""Meta information of hlabel classification."""

hlabel_info: HLabelInfo
Expand Down Expand Up @@ -110,7 +110,7 @@ def __init__(self, **kwargs) -> None:

# Hlabel classification used HLabelMetaInfo to insert the HLabelInfo.
self.meta_info = HLabelMetaInfo(
class_names=[category.name for category in self.dm_categories],
label_names=[category.name for category in self.dm_categories],
hlabel_info=HLabelInfo.from_dm_label_groups(self.dm_categories),
)

Expand Down
6 changes: 3 additions & 3 deletions src/otx/core/data/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader

from otx.core.data.dataset.base import DataMetaInfo
from otx.core.data.dataset.base import LabelInfo
from otx.core.data.factory import OTXDatasetFactory
from otx.core.data.mem_cache import (
MemCacheHandlerSingleton,
Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(
mem_size=mem_size,
)

meta_infos: list[DataMetaInfo] = []
meta_infos: list[LabelInfo] = []
for name, dm_subset in dataset.subsets().items():
if name not in config_mapping:
log.warning(f"{name} is not available. Skip it")
Expand All @@ -91,7 +91,7 @@ def __init__(

self.meta_info = next(iter(meta_infos))

def _is_meta_info_valid(self, meta_infos: list[DataMetaInfo]) -> bool:
def _is_meta_info_valid(self, meta_infos: list[LabelInfo]) -> bool:
"""Check whether there are mismatches in the metainfo for the all subsets."""
if all(meta_info == meta_infos[0] for meta_info in meta_infos):
return True
Expand Down
6 changes: 4 additions & 2 deletions src/otx/core/model/entity/action_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.model.entity.base import OTXModel
from otx.core.utils.build import build_mm_model, get_classification_layers
from otx.core.utils.config import inplace_num_classes

if TYPE_CHECKING:
from omegaconf import DictConfig
Expand All @@ -32,10 +33,11 @@ class MMActionCompatibleModel(OTXActionClsModel):
compatible for OTX pipelines.
"""

def __init__(self, config: DictConfig) -> None:
def __init__(self, num_classes: int, config: DictConfig) -> None:
config = inplace_num_classes(cfg=config, num_classes=num_classes)
self.config = config
self.load_from = config.pop("load_from", None)
super().__init__()
super().__init__(num_classes=num_classes)

def _create_model(self) -> nn.Module:
from mmaction.models.data_preprocessors import (
Expand Down
6 changes: 4 additions & 2 deletions src/otx/core/model/entity/action_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.model.entity.base import OTXModel
from otx.core.utils.build import build_mm_model, get_classification_layers
from otx.core.utils.config import inplace_num_classes

if TYPE_CHECKING:
from omegaconf import DictConfig
Expand All @@ -31,10 +32,11 @@ class MMActionCompatibleModel(OTXActionDetModel):
compatible for OTX pipelines.
"""

def __init__(self, config: DictConfig) -> None:
def __init__(self, num_classes: int, config: DictConfig) -> None:
config = inplace_num_classes(cfg=config, num_classes=num_classes)
self.config = config
self.load_from = config.pop("load_from", None)
super().__init__()
super().__init__(num_classes=num_classes)

def _create_model(self) -> nn.Module:
from mmaction.models.data_preprocessors import (
Expand Down
46 changes: 44 additions & 2 deletions src/otx/core/model/entity/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

from __future__ import annotations

import warnings
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Generic

from torch import nn

from otx.core.data.dataset.base import LabelInfo
from otx.core.data.entity.base import (
OTXBatchLossEntity,
T_OTXBatchDataEntity,
Expand All @@ -24,13 +26,45 @@


class OTXModel(nn.Module, Generic[T_OTXBatchDataEntity, T_OTXBatchPredEntity]):
"""Base class for the models used in OTX."""
"""Base class for the models used in OTX.
def __init__(self) -> None:
Args:
num_classes: Number of classes this model can predict.
"""

def __init__(self, num_classes: int) -> None:
super().__init__()

self._label_info = LabelInfo.from_num_classes(num_classes)
self.classification_layers: dict[str, dict[str, Any]] = {}
self.model = self._create_model()

@property
def label_info(self) -> LabelInfo:
"""Get this model label information."""
return self._label_info

@label_info.setter
def label_info(self, label_info: LabelInfo | list[str]) -> None:
"""Set this model label information."""
if isinstance(label_info, list):
label_info = LabelInfo(label_names=label_info)

old_num_classes = self._label_info.num_classes
new_num_classes = label_info.num_classes

if old_num_classes != new_num_classes:
msg = (
f"Given LabelInfo has the different number of classes "
f"({old_num_classes}!={new_num_classes}). "
"The model prediction layer is reset to the new number of classes "
f"(={new_num_classes})."
)
warnings.warn(msg, stacklevel=0)
self._reset_prediction_layer(num_classes=label_info.num_classes)

self._label_info = label_info

@abstractmethod
def _create_model(self) -> nn.Module:
"""Create a PyTorch model for this class."""
Expand Down Expand Up @@ -164,3 +198,11 @@ def register_explain_hook(self) -> None:
TBD
"""
raise NotImplementedError

def _reset_prediction_layer(self, num_classes: int) -> None:
"""Reset its prediction layer with a given number of classes.
Args:
num_classes: Number of classes
"""
raise NotImplementedError
21 changes: 13 additions & 8 deletions src/otx/core/model/entity/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from otx.core.model.entity.base import OTXModel
from otx.core.utils.build import build_mm_model, get_classification_layers
from otx.core.utils.config import inplace_num_classes

if TYPE_CHECKING:
from mmpretrain.models.utils import ClsDataPreprocessor
Expand Down Expand Up @@ -63,10 +64,11 @@ class MMPretrainMulticlassClsModel(OTXMulticlassClsModel):
compatible for OTX pipelines.
"""

def __init__(self, config: DictConfig) -> None:
def __init__(self, num_classes: int, config: DictConfig) -> None:
config = inplace_num_classes(cfg=config, num_classes=num_classes)
self.config = config
self.load_from = config.pop("load_from", None)
super().__init__()
super().__init__(num_classes=num_classes)

def _create_model(self) -> nn.Module:
model, classification_layers = _create_mmpretrain_model(self.config, self.load_from)
Expand Down Expand Up @@ -155,10 +157,11 @@ class MMPretrainMultilabelClsModel(OTXMultilabelClsModel):
compatible for OTX pipelines.
"""

def __init__(self, config: DictConfig) -> None:
def __init__(self, num_classes: int, config: DictConfig) -> None:
config = inplace_num_classes(cfg=config, num_classes=num_classes)
self.config = config
self.load_from = config.pop("load_from", None)
super().__init__()
super().__init__(num_classes=num_classes)

def _create_model(self) -> nn.Module:
model, classification_layers = _create_mmpretrain_model(self.config, self.load_from)
Expand Down Expand Up @@ -241,10 +244,11 @@ class MMPretrainHlabelClsModel(OTXHlabelClsModel):
compatible for OTX pipelines.
"""

def __init__(self, config: DictConfig) -> None:
def __init__(self, num_classes: int, config: DictConfig) -> None:
config = inplace_num_classes(cfg=config, num_classes=num_classes)
self.config = config
self.load_from = config.pop("load_from", None)
super().__init__()
super().__init__(num_classes=num_classes)

def _create_model(self) -> nn.Module:
model, classification_layers = _create_mmpretrain_model(self.config, self.load_from)
Expand Down Expand Up @@ -322,10 +326,11 @@ class OVClassificationCompatibleModel(OTXMulticlassClsModel):
and create the OTX classification model compatible for OTX testing pipeline.
"""

def __init__(self, config: DictConfig) -> None:
def __init__(self, num_classes: int, config: DictConfig) -> None:
self.model_name = config.pop("model_name")
config = inplace_num_classes(cfg=config, num_classes=num_classes)
self.config = config
super().__init__()
super().__init__(num_classes=num_classes)

def _create_model(self) -> nn.Module:
from openvino.model_api.models import ClassificationModel
Expand Down
10 changes: 6 additions & 4 deletions src/otx/core/model/entity/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from otx.core.data.entity.detection import DetBatchDataEntity, DetBatchPredEntity
from otx.core.model.entity.base import OTXModel
from otx.core.utils.build import build_mm_model, get_classification_layers
from otx.core.utils.config import inplace_num_classes

if TYPE_CHECKING:
from mmdet.models.data_preprocessors import DetDataPreprocessor
Expand All @@ -35,10 +36,11 @@ class MMDetCompatibleModel(OTXDetectionModel):
compatible for OTX pipelines.
"""

def __init__(self, config: DictConfig) -> None:
def __init__(self, num_classes: int, config: DictConfig) -> None:
config = inplace_num_classes(cfg=config, num_classes=num_classes)
self.config = config
self.load_from = config.pop("load_from", None)
super().__init__()
super().__init__(num_classes=num_classes)

def _create_model(self) -> nn.Module:
from mmdet.models.data_preprocessors import (
Expand Down Expand Up @@ -154,10 +156,10 @@ class OVDetectionCompatibleModel(OTXDetectionModel):
and create the OTX detection model compatible for OTX testing pipeline.
"""

def __init__(self, config: DictConfig) -> None:
def __init__(self, num_classes: int, config: DictConfig) -> None:
self.model_name = config.pop("model_name")
self.config = config
super().__init__()
super().__init__(num_classes=num_classes)

def _create_model(self) -> nn.Module:
from openvino.model_api.models import DetectionModel
Expand Down
11 changes: 7 additions & 4 deletions src/otx/core/model/entity/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from otx.core.model.entity.base import OTXModel
from otx.core.utils.build import build_mm_model, get_classification_layers
from otx.core.utils.config import inplace_num_classes

if TYPE_CHECKING:
from mmdet.models.data_preprocessors import DetDataPreprocessor
Expand All @@ -34,10 +35,11 @@ class OTXInstanceSegModel(
class MMDetInstanceSegCompatibleModel(OTXInstanceSegModel):
"""Instance Segmentation model compatible for MMDet."""

def __init__(self, config: DictConfig) -> None:
def __init__(self, num_classes: int, config: DictConfig) -> None:
config = inplace_num_classes(cfg=config, num_classes=num_classes)
self.config = config
self.load_from = self.config.pop("load_from", None)
super().__init__()
super().__init__(num_classes=num_classes)

def _create_model(self) -> nn.Module:
from mmdet.models.data_preprocessors import (
Expand Down Expand Up @@ -176,11 +178,12 @@ class OVInstanceSegCompatibleModel(OTXInstanceSegModel):
and create the OTX detection model compatible for OTX testing pipeline.
"""

def __init__(self, config: DictConfig) -> None:
def __init__(self, num_classes: int, config: DictConfig) -> None:
self.model_name = config.pop("model_name")
self.model_type = config.pop("model_type")
config = inplace_num_classes(cfg=config, num_classes=num_classes)
self.config = config
super().__init__()
super().__init__(num_classes=num_classes)

def _create_model(self) -> nn.Module:
from openvino.model_api.models import Model
Expand Down
Loading

0 comments on commit bc6e7fc

Please sign in to comment.