Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pre-work for the design change: create own python class for every model #2775

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/otx/algo/classification/otx_dino_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def forward(self, imgs: torch.Tensor, labels: torch.Tensor = None) -> torch.Tens
class DINOv2RegisterClassifier(OTXMulticlassClsModel):
"""DINO-v2 Classification Model with register."""

def __init__(self, config: DictConfig) -> None:
def __init__(self, num_classes: int, config: DictConfig) -> None:
self.config = config
super().__init__() # create the model
super().__init__(num_classes=num_classes) # create the model

def _create_model(self) -> nn.Module:
"""Create the model."""
Expand Down
3 changes: 3 additions & 0 deletions src/otx/config/model/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ scheduler:
mode: min
factor: 0.1
patience: 10

otx_model:
num_classes: ???
3 changes: 3 additions & 0 deletions src/otx/config/model/mmdet_inst_seg.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
defaults:
- default

_target_: otx.core.model.module.instance_segmentation.OTXInstanceSegLitModule

optimizer:
Expand Down
1 change: 1 addition & 0 deletions src/otx/config/model/mmseg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ scheduler:
otx_model:
_target_: otx.core.model.entity.segmentation.MMSegCompatibleModel
config: ???
num_classes: ???

# compile model for faster training with pytorch 2.0
torch_compile: false
26 changes: 19 additions & 7 deletions src/otx/core/data/dataset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,27 @@


@dataclass
class DataMetaInfo:
"""Meta information of each subset datasets."""
class LabelInfo:
"""Object to represent label information."""
jaegukhyun marked this conversation as resolved.
Show resolved Hide resolved

class_names: list[str]
label_names: list[str]

@property
def num_classes(self) -> int:
"""Return number of classes."""
return len(self.class_names)
"""Return number of labels."""
return len(self.label_names)

@classmethod
def from_num_classes(cls, num_classes: int) -> LabelInfo:
"""Create this object from the number of classes.

Args:
num_classes: Number of classes

Returns:
LabelInfo(label_names=["label_0", ...])
"""
return LabelInfo(label_names=[f"label_{idx}" for idx in range(num_classes)])


class OTXDataset(Dataset, Generic[T_OTXDataEntity]):
Expand All @@ -59,8 +71,8 @@ def __init__(
self.mem_cache_img_max_size = mem_cache_img_max_size
self.max_refetch = max_refetch

self.meta_info = DataMetaInfo(
class_names=[category.name for category in self.dm_subset.categories()[AnnotationType.label]],
self.meta_info = LabelInfo(
label_names=[category.name for category in self.dm_subset.categories()[AnnotationType.label]],
)

def __len__(self) -> int:
Expand Down
6 changes: 3 additions & 3 deletions src/otx/core/data/dataset/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from datumaro.components.annotation import AnnotationType
from torch.nn import functional

from otx.core.data.dataset.base import DataMetaInfo, OTXDataset
from otx.core.data.dataset.base import LabelInfo, OTXDataset
from otx.core.data.entity.base import ImageInfo
from otx.core.data.entity.classification import (
HlabelClsBatchDataEntity,
Expand All @@ -27,7 +27,7 @@


@dataclass
class HLabelMetaInfo(DataMetaInfo):
class HLabelMetaInfo(LabelInfo):
"""Meta information of hlabel classification."""

hlabel_info: HLabelInfo
Expand Down Expand Up @@ -110,7 +110,7 @@ def __init__(self, **kwargs) -> None:

# Hlabel classification used HLabelMetaInfo to insert the HLabelInfo.
self.meta_info = HLabelMetaInfo(
class_names=[category.name for category in self.dm_categories],
label_names=[category.name for category in self.dm_categories],
hlabel_info=HLabelInfo.from_dm_label_groups(self.dm_categories),
)

Expand Down
6 changes: 3 additions & 3 deletions src/otx/core/data/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader

from otx.core.data.dataset.base import DataMetaInfo
from otx.core.data.dataset.base import LabelInfo
from otx.core.data.factory import OTXDatasetFactory
from otx.core.data.mem_cache import (
MemCacheHandlerSingleton,
Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(
mem_size=mem_size,
)

meta_infos: list[DataMetaInfo] = []
meta_infos: list[LabelInfo] = []
for name, dm_subset in dataset.subsets().items():
if name not in config_mapping:
log.warning(f"{name} is not available. Skip it")
Expand All @@ -91,7 +91,7 @@ def __init__(

self.meta_info = next(iter(meta_infos))

def _is_meta_info_valid(self, meta_infos: list[DataMetaInfo]) -> bool:
def _is_meta_info_valid(self, meta_infos: list[LabelInfo]) -> bool:
"""Check whether there are mismatches in the metainfo for the all subsets."""
if all(meta_info == meta_infos[0] for meta_info in meta_infos):
return True
Expand Down
6 changes: 4 additions & 2 deletions src/otx/core/model/entity/action_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.model.entity.base import OTXModel
from otx.core.utils.build import build_mm_model, get_classification_layers
from otx.core.utils.config import inplace_num_classes

if TYPE_CHECKING:
from omegaconf import DictConfig
Expand All @@ -32,10 +33,11 @@ class MMActionCompatibleModel(OTXActionClsModel):
compatible for OTX pipelines.
"""

def __init__(self, config: DictConfig) -> None:
def __init__(self, num_classes: int, config: DictConfig) -> None:
config = inplace_num_classes(cfg=config, num_classes=num_classes)
self.config = config
self.load_from = config.pop("load_from", None)
super().__init__()
super().__init__(num_classes=num_classes)

def _create_model(self) -> nn.Module:
from mmaction.models.data_preprocessors import (
Expand Down
6 changes: 4 additions & 2 deletions src/otx/core/model/entity/action_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.model.entity.base import OTXModel
from otx.core.utils.build import build_mm_model, get_classification_layers
from otx.core.utils.config import inplace_num_classes

if TYPE_CHECKING:
from omegaconf import DictConfig
Expand All @@ -31,10 +32,11 @@ class MMActionCompatibleModel(OTXActionDetModel):
compatible for OTX pipelines.
"""

def __init__(self, config: DictConfig) -> None:
def __init__(self, num_classes: int, config: DictConfig) -> None:
config = inplace_num_classes(cfg=config, num_classes=num_classes)
self.config = config
self.load_from = config.pop("load_from", None)
super().__init__()
super().__init__(num_classes=num_classes)

def _create_model(self) -> nn.Module:
from mmaction.models.data_preprocessors import (
Expand Down
94 changes: 92 additions & 2 deletions src/otx/core/model/entity/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,66 @@

from __future__ import annotations

import warnings
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Generic

from torch import nn

from otx.core.data.dataset.base import LabelInfo
from otx.core.data.entity.base import (
OTXBatchLossEntity,
T_OTXBatchDataEntity,
T_OTXBatchPredEntity,
)
from otx.core.types.export import OTXExportFormat

if TYPE_CHECKING:
from pathlib import Path

import torch


class OTXModel(nn.Module, Generic[T_OTXBatchDataEntity, T_OTXBatchPredEntity]):
"""Base class for the models used in OTX."""
"""Base class for the models used in OTX.

Args:
num_classes: Number of classes this model can predict.
"""

def __init__(self) -> None:
def __init__(self, num_classes: int) -> None:
super().__init__()

self._label_info = LabelInfo.from_num_classes(num_classes)
self.classification_layers: dict[str, dict[str, Any]] = {}
self.model = self._create_model()

@property
def label_info(self) -> LabelInfo:
"""Get this model label information."""
return self._label_info

@label_info.setter
def label_info(self, label_info: LabelInfo | list[str]) -> None:
"""Set this model label information."""
if isinstance(label_info, list):
label_info = LabelInfo(label_names=label_info)

old_num_classes = self._label_info.num_classes
new_num_classes = label_info.num_classes

if old_num_classes != new_num_classes:
msg = (
f"Given LabelInfo has the different number of classes "
f"({old_num_classes}!={new_num_classes}). "
"The model prediction layer is reset to the new number of classes "
f"(={new_num_classes})."
)
warnings.warn(msg, stacklevel=0)
self._reset_prediction_layer(num_classes=label_info.num_classes)
jaegukhyun marked this conversation as resolved.
Show resolved Hide resolved

self._label_info = label_info

@abstractmethod
def _create_model(self) -> nn.Module:
"""Create a PyTorch model for this class."""
Expand Down Expand Up @@ -116,3 +153,56 @@ def map_class_names(src_classes: list[str], dst_classes: list[str]) -> list[int]
else:
src2dst.append(-1)
return src2dst

def export(self, output_dir: Path, export_format: OTXExportFormat) -> None:
"""Export this model to the specified output directory.

Args:
output_dir: Directory path to save exported binary files.
export_format: Format in which this `OTXModel` is exported.
"""
if export_format == OTXExportFormat.OPENVINO:
self._export_to_openvino(output_dir)
if export_format == OTXExportFormat.ONNX:
self._export_to_onnx()
if export_format == OTXExportFormat.EXPORTABLE_CODE:
self._export_to_exportable_code()

def _export_to_openvino(self, output_dir: Path) -> None:
"""Export to OpenVINO Intermediate Representation format.

Args:
output_dir: Directory path to save exported binary files
"""
raise NotImplementedError

def _export_to_onnx(self) -> None:
"""Export to ONNX format.

Args:
output_dir: Directory path to save exported binary files
"""
raise NotImplementedError

def _export_to_exportable_code(self) -> None:
"""Export to exportable code format.

Args:
output_dir: Directory path to save exported binary files
"""
raise NotImplementedError

def register_explain_hook(self) -> None:
"""Register explain hook.

TBD
"""
raise NotImplementedError

def _reset_prediction_layer(self, num_classes: int) -> None:
"""Reset its prediction layer with a given number of classes.

Args:
num_classes: Number of classes
"""
raise NotImplementedError
21 changes: 13 additions & 8 deletions src/otx/core/model/entity/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from otx.core.model.entity.base import OTXModel
from otx.core.utils.build import build_mm_model, get_classification_layers
from otx.core.utils.config import inplace_num_classes

if TYPE_CHECKING:
from mmpretrain.models.utils import ClsDataPreprocessor
Expand Down Expand Up @@ -63,10 +64,11 @@ class MMPretrainMulticlassClsModel(OTXMulticlassClsModel):
compatible for OTX pipelines.
"""

def __init__(self, config: DictConfig) -> None:
def __init__(self, num_classes: int, config: DictConfig) -> None:
config = inplace_num_classes(cfg=config, num_classes=num_classes)
self.config = config
self.load_from = config.pop("load_from", None)
super().__init__()
super().__init__(num_classes=num_classes)

def _create_model(self) -> nn.Module:
model, classification_layers = _create_mmpretrain_model(self.config, self.load_from)
Expand Down Expand Up @@ -155,10 +157,11 @@ class MMPretrainMultilabelClsModel(OTXMultilabelClsModel):
compatible for OTX pipelines.
"""

def __init__(self, config: DictConfig) -> None:
def __init__(self, num_classes: int, config: DictConfig) -> None:
config = inplace_num_classes(cfg=config, num_classes=num_classes)
self.config = config
self.load_from = config.pop("load_from", None)
super().__init__()
super().__init__(num_classes=num_classes)

def _create_model(self) -> nn.Module:
model, classification_layers = _create_mmpretrain_model(self.config, self.load_from)
Expand Down Expand Up @@ -241,10 +244,11 @@ class MMPretrainHlabelClsModel(OTXHlabelClsModel):
compatible for OTX pipelines.
"""

def __init__(self, config: DictConfig) -> None:
def __init__(self, num_classes: int, config: DictConfig) -> None:
config = inplace_num_classes(cfg=config, num_classes=num_classes)
self.config = config
self.load_from = config.pop("load_from", None)
super().__init__()
super().__init__(num_classes=num_classes)

def _create_model(self) -> nn.Module:
model, classification_layers = _create_mmpretrain_model(self.config, self.load_from)
Expand Down Expand Up @@ -322,10 +326,11 @@ class OVClassificationCompatibleModel(OTXMulticlassClsModel):
and create the OTX classification model compatible for OTX testing pipeline.
"""

def __init__(self, config: DictConfig) -> None:
def __init__(self, num_classes: int, config: DictConfig) -> None:
self.model_name = config.pop("model_name")
config = inplace_num_classes(cfg=config, num_classes=num_classes)
self.config = config
super().__init__()
super().__init__(num_classes=num_classes)

def _create_model(self) -> nn.Module:
from openvino.model_api.models import ClassificationModel
Expand Down
Loading