Skip to content

Commit

Permalink
Pre-work for the design change: create own python class for every mod…
Browse files Browse the repository at this point in the history
…el (#2775)

Signed-off-by: Kim, Vinnam <vinnam.kim@intel.com>
  • Loading branch information
vinnamkim authored Jan 11, 2024
1 parent 0e8bd95 commit 02c3938
Show file tree
Hide file tree
Showing 28 changed files with 260 additions and 83 deletions.
4 changes: 2 additions & 2 deletions src/otx/algo/classification/otx_dino_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def forward(self, imgs: torch.Tensor, labels: torch.Tensor = None) -> torch.Tens
class DINOv2RegisterClassifier(OTXMulticlassClsModel):
"""DINO-v2 Classification Model with register."""

def __init__(self, config: DictConfig) -> None:
def __init__(self, num_classes: int, config: DictConfig) -> None:
self.config = config
super().__init__() # create the model
super().__init__(num_classes=num_classes) # create the model

def _create_model(self) -> nn.Module:
"""Create the model."""
Expand Down
3 changes: 3 additions & 0 deletions src/otx/config/model/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ scheduler:
mode: min
factor: 0.1
patience: 10

otx_model:
num_classes: ???
3 changes: 3 additions & 0 deletions src/otx/config/model/mmdet_inst_seg.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
defaults:
- default

_target_: otx.core.model.module.instance_segmentation.OTXInstanceSegLitModule

optimizer:
Expand Down
1 change: 1 addition & 0 deletions src/otx/config/model/mmseg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ scheduler:
otx_model:
_target_: otx.core.model.entity.segmentation.MMSegCompatibleModel
config: ???
num_classes: ???

# compile model for faster training with pytorch 2.0
torch_compile: false
26 changes: 19 additions & 7 deletions src/otx/core/data/dataset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,27 @@


@dataclass
class DataMetaInfo:
"""Meta information of each subset datasets."""
class LabelInfo:
"""Object to represent label information."""

class_names: list[str]
label_names: list[str]

@property
def num_classes(self) -> int:
"""Return number of classes."""
return len(self.class_names)
"""Return number of labels."""
return len(self.label_names)

@classmethod
def from_num_classes(cls, num_classes: int) -> LabelInfo:
"""Create this object from the number of classes.
Args:
num_classes: Number of classes
Returns:
LabelInfo(label_names=["label_0", ...])
"""
return LabelInfo(label_names=[f"label_{idx}" for idx in range(num_classes)])


class OTXDataset(Dataset, Generic[T_OTXDataEntity]):
Expand All @@ -59,8 +71,8 @@ def __init__(
self.mem_cache_img_max_size = mem_cache_img_max_size
self.max_refetch = max_refetch

self.meta_info = DataMetaInfo(
class_names=[category.name for category in self.dm_subset.categories()[AnnotationType.label]],
self.meta_info = LabelInfo(
label_names=[category.name for category in self.dm_subset.categories()[AnnotationType.label]],
)

def __len__(self) -> int:
Expand Down
6 changes: 3 additions & 3 deletions src/otx/core/data/dataset/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from datumaro.components.annotation import AnnotationType
from torch.nn import functional

from otx.core.data.dataset.base import DataMetaInfo, OTXDataset
from otx.core.data.dataset.base import LabelInfo, OTXDataset
from otx.core.data.entity.base import ImageInfo
from otx.core.data.entity.classification import (
HlabelClsBatchDataEntity,
Expand All @@ -27,7 +27,7 @@


@dataclass
class HLabelMetaInfo(DataMetaInfo):
class HLabelMetaInfo(LabelInfo):
"""Meta information of hlabel classification."""

hlabel_info: HLabelInfo
Expand Down Expand Up @@ -110,7 +110,7 @@ def __init__(self, **kwargs) -> None:

# Hlabel classification used HLabelMetaInfo to insert the HLabelInfo.
self.meta_info = HLabelMetaInfo(
class_names=[category.name for category in self.dm_categories],
label_names=[category.name for category in self.dm_categories],
hlabel_info=HLabelInfo.from_dm_label_groups(self.dm_categories),
)

Expand Down
6 changes: 3 additions & 3 deletions src/otx/core/data/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader

from otx.core.data.dataset.base import DataMetaInfo
from otx.core.data.dataset.base import LabelInfo
from otx.core.data.factory import OTXDatasetFactory
from otx.core.data.mem_cache import (
MemCacheHandlerSingleton,
Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(
mem_size=mem_size,
)

meta_infos: list[DataMetaInfo] = []
meta_infos: list[LabelInfo] = []
for name, dm_subset in dataset.subsets().items():
if name not in config_mapping:
log.warning(f"{name} is not available. Skip it")
Expand All @@ -91,7 +91,7 @@ def __init__(

self.meta_info = next(iter(meta_infos))

def _is_meta_info_valid(self, meta_infos: list[DataMetaInfo]) -> bool:
def _is_meta_info_valid(self, meta_infos: list[LabelInfo]) -> bool:
"""Check whether there are mismatches in the metainfo for the all subsets."""
if all(meta_info == meta_infos[0] for meta_info in meta_infos):
return True
Expand Down
6 changes: 4 additions & 2 deletions src/otx/core/model/entity/action_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.model.entity.base import OTXModel
from otx.core.utils.build import build_mm_model, get_classification_layers
from otx.core.utils.config import inplace_num_classes

if TYPE_CHECKING:
from omegaconf import DictConfig
Expand All @@ -32,10 +33,11 @@ class MMActionCompatibleModel(OTXActionClsModel):
compatible for OTX pipelines.
"""

def __init__(self, config: DictConfig) -> None:
def __init__(self, num_classes: int, config: DictConfig) -> None:
config = inplace_num_classes(cfg=config, num_classes=num_classes)
self.config = config
self.load_from = config.pop("load_from", None)
super().__init__()
super().__init__(num_classes=num_classes)

def _create_model(self) -> nn.Module:
from mmaction.models.data_preprocessors import (
Expand Down
6 changes: 4 additions & 2 deletions src/otx/core/model/entity/action_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.model.entity.base import OTXModel
from otx.core.utils.build import build_mm_model, get_classification_layers
from otx.core.utils.config import inplace_num_classes

if TYPE_CHECKING:
from omegaconf import DictConfig
Expand All @@ -31,10 +32,11 @@ class MMActionCompatibleModel(OTXActionDetModel):
compatible for OTX pipelines.
"""

def __init__(self, config: DictConfig) -> None:
def __init__(self, num_classes: int, config: DictConfig) -> None:
config = inplace_num_classes(cfg=config, num_classes=num_classes)
self.config = config
self.load_from = config.pop("load_from", None)
super().__init__()
super().__init__(num_classes=num_classes)

def _create_model(self) -> nn.Module:
from mmaction.models.data_preprocessors import (
Expand Down
94 changes: 92 additions & 2 deletions src/otx/core/model/entity/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,66 @@

from __future__ import annotations

import warnings
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Generic

from torch import nn

from otx.core.data.dataset.base import LabelInfo
from otx.core.data.entity.base import (
OTXBatchLossEntity,
T_OTXBatchDataEntity,
T_OTXBatchPredEntity,
)
from otx.core.types.export import OTXExportFormat

if TYPE_CHECKING:
from pathlib import Path

import torch


class OTXModel(nn.Module, Generic[T_OTXBatchDataEntity, T_OTXBatchPredEntity]):
"""Base class for the models used in OTX."""
"""Base class for the models used in OTX.
Args:
num_classes: Number of classes this model can predict.
"""

def __init__(self) -> None:
def __init__(self, num_classes: int) -> None:
super().__init__()

self._label_info = LabelInfo.from_num_classes(num_classes)
self.classification_layers: dict[str, dict[str, Any]] = {}
self.model = self._create_model()

@property
def label_info(self) -> LabelInfo:
"""Get this model label information."""
return self._label_info

@label_info.setter
def label_info(self, label_info: LabelInfo | list[str]) -> None:
"""Set this model label information."""
if isinstance(label_info, list):
label_info = LabelInfo(label_names=label_info)

old_num_classes = self._label_info.num_classes
new_num_classes = label_info.num_classes

if old_num_classes != new_num_classes:
msg = (
f"Given LabelInfo has the different number of classes "
f"({old_num_classes}!={new_num_classes}). "
"The model prediction layer is reset to the new number of classes "
f"(={new_num_classes})."
)
warnings.warn(msg, stacklevel=0)
self._reset_prediction_layer(num_classes=label_info.num_classes)

self._label_info = label_info

@abstractmethod
def _create_model(self) -> nn.Module:
"""Create a PyTorch model for this class."""
Expand Down Expand Up @@ -116,3 +153,56 @@ def map_class_names(src_classes: list[str], dst_classes: list[str]) -> list[int]
else:
src2dst.append(-1)
return src2dst

def export(self, output_dir: Path, export_format: OTXExportFormat) -> None:
"""Export this model to the specified output directory.
Args:
output_dir: Directory path to save exported binary files.
export_format: Format in which this `OTXModel` is exported.
"""
if export_format == OTXExportFormat.OPENVINO:
self._export_to_openvino(output_dir)
if export_format == OTXExportFormat.ONNX:
self._export_to_onnx()
if export_format == OTXExportFormat.EXPORTABLE_CODE:
self._export_to_exportable_code()

def _export_to_openvino(self, output_dir: Path) -> None:
"""Export to OpenVINO Intermediate Representation format.
Args:
output_dir: Directory path to save exported binary files
"""
raise NotImplementedError

def _export_to_onnx(self) -> None:
"""Export to ONNX format.
Args:
output_dir: Directory path to save exported binary files
"""
raise NotImplementedError

def _export_to_exportable_code(self) -> None:
"""Export to exportable code format.
Args:
output_dir: Directory path to save exported binary files
"""
raise NotImplementedError

def register_explain_hook(self) -> None:
"""Register explain hook.
TBD
"""
raise NotImplementedError

def _reset_prediction_layer(self, num_classes: int) -> None:
"""Reset its prediction layer with a given number of classes.
Args:
num_classes: Number of classes
"""
raise NotImplementedError
21 changes: 13 additions & 8 deletions src/otx/core/model/entity/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from otx.core.model.entity.base import OTXModel
from otx.core.utils.build import build_mm_model, get_classification_layers
from otx.core.utils.config import inplace_num_classes

if TYPE_CHECKING:
from mmpretrain.models.utils import ClsDataPreprocessor
Expand Down Expand Up @@ -63,10 +64,11 @@ class MMPretrainMulticlassClsModel(OTXMulticlassClsModel):
compatible for OTX pipelines.
"""

def __init__(self, config: DictConfig) -> None:
def __init__(self, num_classes: int, config: DictConfig) -> None:
config = inplace_num_classes(cfg=config, num_classes=num_classes)
self.config = config
self.load_from = config.pop("load_from", None)
super().__init__()
super().__init__(num_classes=num_classes)

def _create_model(self) -> nn.Module:
model, classification_layers = _create_mmpretrain_model(self.config, self.load_from)
Expand Down Expand Up @@ -155,10 +157,11 @@ class MMPretrainMultilabelClsModel(OTXMultilabelClsModel):
compatible for OTX pipelines.
"""

def __init__(self, config: DictConfig) -> None:
def __init__(self, num_classes: int, config: DictConfig) -> None:
config = inplace_num_classes(cfg=config, num_classes=num_classes)
self.config = config
self.load_from = config.pop("load_from", None)
super().__init__()
super().__init__(num_classes=num_classes)

def _create_model(self) -> nn.Module:
model, classification_layers = _create_mmpretrain_model(self.config, self.load_from)
Expand Down Expand Up @@ -241,10 +244,11 @@ class MMPretrainHlabelClsModel(OTXHlabelClsModel):
compatible for OTX pipelines.
"""

def __init__(self, config: DictConfig) -> None:
def __init__(self, num_classes: int, config: DictConfig) -> None:
config = inplace_num_classes(cfg=config, num_classes=num_classes)
self.config = config
self.load_from = config.pop("load_from", None)
super().__init__()
super().__init__(num_classes=num_classes)

def _create_model(self) -> nn.Module:
model, classification_layers = _create_mmpretrain_model(self.config, self.load_from)
Expand Down Expand Up @@ -322,10 +326,11 @@ class OVClassificationCompatibleModel(OTXMulticlassClsModel):
and create the OTX classification model compatible for OTX testing pipeline.
"""

def __init__(self, config: DictConfig) -> None:
def __init__(self, num_classes: int, config: DictConfig) -> None:
self.model_name = config.pop("model_name")
config = inplace_num_classes(cfg=config, num_classes=num_classes)
self.config = config
super().__init__()
super().__init__(num_classes=num_classes)

def _create_model(self) -> nn.Module:
from openvino.model_api.models import ClassificationModel
Expand Down
Loading

0 comments on commit 02c3938

Please sign in to comment.