Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/v2' into vs/v2_export
Browse files Browse the repository at this point in the history
  • Loading branch information
sovrasov committed Jan 26, 2024
2 parents 74be80b + ad2a0ed commit 213e8e4
Show file tree
Hide file tree
Showing 35 changed files with 689 additions and 184 deletions.
5 changes: 5 additions & 0 deletions src/otx/algo/action_classification/x3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""X3D model implementation."""

from otx.algo.utils.mmconfig import read_mmconfig
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.model.entity.action_classification import MMActionCompatibleModel


Expand All @@ -13,3 +14,7 @@ class X3D(MMActionCompatibleModel):
def __init__(self, num_classes: int) -> None:
config = read_mmconfig("x3d")
super().__init__(num_classes=num_classes, config=config)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_action_ckpt(state_dict, add_prefix)
5 changes: 5 additions & 0 deletions src/otx/algo/action_detection/x3d_fastrcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

from otx.algo.utils.mmconfig import read_mmconfig
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.model.entity.action_detection import MMActionCompatibleModel


Expand All @@ -15,3 +16,7 @@ def __init__(self, num_classes: int, topk: int | tuple[int]):
config = read_mmconfig("x3d_fastrcnn")
config.roi_head.bbox_head.topk = topk
super().__init__(num_classes=num_classes, config=config)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_action_ckpt(state_dict, add_prefix)
13 changes: 13 additions & 0 deletions src/otx/algo/classification/deit_tiny.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""DeitTiny model implementation."""

from otx.algo.utils.mmconfig import read_mmconfig
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.model.entity.classification import (
MMPretrainHlabelClsModel,
MMPretrainMulticlassClsModel,
Expand All @@ -20,6 +21,10 @@ def __init__(self, num_classes: int, num_multiclass_heads: int, num_multilabel_c
config.head.num_multilabel_classes = num_multilabel_classes
super().__init__(num_classes=num_classes, config=config)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "multiclass", add_prefix)


class DeitTinyForMulticlassCls(MMPretrainMulticlassClsModel):
"""DeitTiny Model for multi-label classification task."""
Expand All @@ -28,10 +33,18 @@ def __init__(self, num_classes: int) -> None:
config = read_mmconfig("deit_tiny", subdir_name="multiclass_classification")
super().__init__(num_classes=num_classes, config=config)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "multiclass", add_prefix)


class DeitTinyForMultilabelCls(MMPretrainMultilabelClsModel):
"""DeitTiny Model for multi-class classification task."""

def __init__(self, num_classes: int) -> None:
config = read_mmconfig("deit_tiny", subdir_name="multilabel_classification")
super().__init__(num_classes=num_classes, config=config)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "multiclass", add_prefix)
14 changes: 13 additions & 1 deletion src/otx/algo/classification/efficientnet_b0.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
# SPDX-License-Identifier: Apache-2.0
#
"""EfficientNetB0 model implementation."""

from otx.algo.utils.mmconfig import read_mmconfig
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.model.entity.classification import (
MMPretrainHlabelClsModel,
MMPretrainMulticlassClsModel,
Expand All @@ -20,6 +20,10 @@ def __init__(self, num_classes: int, num_multiclass_heads: int, num_multilabel_c
config.head.num_multilabel_classes = num_multilabel_classes
super().__init__(num_classes=num_classes, config=config)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "hlabel", add_prefix)


class EfficientNetB0ForMulticlassCls(MMPretrainMulticlassClsModel):
"""EfficientNetB0 Model for multi-label classification task."""
Expand All @@ -29,10 +33,18 @@ def __init__(self, num_classes: int, light: bool = False) -> None:
config = read_mmconfig(model_name=model_name, subdir_name="multiclass_classification")
super().__init__(num_classes=num_classes, config=config)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "multiclass", add_prefix)


class EfficientNetB0ForMultilabelCls(MMPretrainMultilabelClsModel):
"""EfficientNetB0 Model for multi-class classification task."""

def __init__(self, num_classes: int) -> None:
config = read_mmconfig(model_name="efficientnet_b0_light", subdir_name="multilabel_classification")
super().__init__(num_classes=num_classes, config=config)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "multilabel", add_prefix)
14 changes: 13 additions & 1 deletion src/otx/algo/classification/efficientnet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
# SPDX-License-Identifier: Apache-2.0
#
"""EfficientNetV2 model implementation."""

from otx.algo.utils.mmconfig import read_mmconfig
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.model.entity.classification import (
MMPretrainHlabelClsModel,
MMPretrainMulticlassClsModel,
Expand All @@ -20,6 +20,10 @@ def __init__(self, num_classes: int, num_multiclass_heads: int, num_multilabel_c
config.head.num_multilabel_classes = num_multilabel_classes
super().__init__(num_classes=num_classes, config=config)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "hlabel", add_prefix)


class EfficientNetV2ForMulticlassCls(MMPretrainMulticlassClsModel):
"""EfficientNetV2 Model for multi-label classification task."""
Expand All @@ -29,10 +33,18 @@ def __init__(self, num_classes: int, light: bool = False) -> None:
config = read_mmconfig(model_name=model_name, subdir_name="multiclass_classification")
super().__init__(num_classes=num_classes, config=config)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "multiclass", add_prefix)


class EfficientNetV2ForMultilabelCls(MMPretrainMultilabelClsModel):
"""EfficientNetV2 Model for multi-class classification task."""

def __init__(self, num_classes: int) -> None:
config = read_mmconfig("efficientnet_v2_light", subdir_name="multilabel_classification")
super().__init__(num_classes=num_classes, config=config)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "multilabel", add_prefix)
14 changes: 13 additions & 1 deletion src/otx/algo/classification/mobilenet_v3_large.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
# SPDX-License-Identifier: Apache-2.0
#
"""MobileNetV3 model implementation."""

from otx.algo.utils.mmconfig import read_mmconfig
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.model.entity.classification import (
MMPretrainHlabelClsModel,
MMPretrainMulticlassClsModel,
Expand All @@ -24,6 +24,10 @@ def _configure_export_parameters(self) -> None:
super()._configure_export_parameters()
self.export_params["via_onnx"] = True

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_cls_mobilenet_v3_ckpt(state_dict, "hlabel", add_prefix)


class MobileNetV3ForMulticlassCls(MMPretrainMulticlassClsModel):
"""MobileNetV3 Model for multi-label classification task."""
Expand All @@ -37,6 +41,10 @@ def _configure_export_parameters(self) -> None:
super()._configure_export_parameters()
self.export_params["via_onnx"] = True

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_cls_mobilenet_v3_ckpt(state_dict, "multiclass", add_prefix)


class MobileNetV3ForMultilabelCls(MMPretrainMultilabelClsModel):
"""MobileNetV3 Model for multi-class classification task."""
Expand All @@ -48,3 +56,7 @@ def __init__(self, num_classes: int) -> None:
def _configure_export_parameters(self) -> None:
super()._configure_export_parameters()
self.export_params["via_onnx"] = True

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_cls_mobilenet_v3_ckpt(state_dict, "multilabel", add_prefix)
5 changes: 5 additions & 0 deletions src/otx/algo/detection/atss.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Literal

from otx.algo.utils.mmconfig import read_mmconfig
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.model.entity.detection import MMDetCompatibleModel


Expand All @@ -16,3 +17,7 @@ def __init__(self, num_classes: int, variant: Literal["mobilenetv2", "r50_fpn",
model_name = f"atss_{variant}"
config = read_mmconfig(model_name=model_name)
super().__init__(num_classes=num_classes, config=config)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_det_ckpt(state_dict, add_prefix)
5 changes: 5 additions & 0 deletions src/otx/algo/detection/rtmdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Literal

from otx.algo.utils.mmconfig import read_mmconfig
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.model.entity.detection import MMDetCompatibleModel


Expand All @@ -16,3 +17,7 @@ def __init__(self, num_classes: int, variant: Literal["tiny"]) -> None:
model_name = f"rtmdet_{variant}"
config = read_mmconfig(model_name=model_name)
super().__init__(num_classes=num_classes, config=config)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_det_ckpt(state_dict, add_prefix)
5 changes: 5 additions & 0 deletions src/otx/algo/detection/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import TYPE_CHECKING, Literal

from otx.algo.utils.mmconfig import read_mmconfig
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.model.entity.detection import MMDetCompatibleModel
from otx.core.utils.build import build_mm_model, modify_num_classes

Expand Down Expand Up @@ -123,3 +124,7 @@ def load_state_dict_pre_hook(self, state_dict: dict[str, torch.Tensor], prefix:

# Replace checkpoint weight by mixed weights
state_dict[prefix + param_name] = model_param

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_det_ckpt(state_dict, add_prefix)
5 changes: 5 additions & 0 deletions src/otx/algo/detection/yolox.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Literal

from otx.algo.utils.mmconfig import read_mmconfig
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.model.entity.detection import MMDetCompatibleModel


Expand All @@ -16,3 +17,7 @@ def __init__(self, num_classes: int, variant: Literal["l", "s", "tiny", "x"]) ->
model_name = f"yolox_{variant}"
config = read_mmconfig(model_name=model_name)
super().__init__(num_classes=num_classes, config=config)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_det_ckpt(state_dict, add_prefix)
5 changes: 5 additions & 0 deletions src/otx/algo/instance_segmentation/maskrcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Literal

from otx.algo.utils.mmconfig import read_mmconfig
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.model.entity.instance_segmentation import MMDetInstanceSegCompatibleModel


Expand All @@ -16,3 +17,7 @@ def __init__(self, num_classes: int, variant: Literal["efficientnetb2b", "r50",
model_name = f"maskrcnn_{variant}"
config = read_mmconfig(model_name=model_name)
super().__init__(num_classes=num_classes, config=config)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_iseg_ckpt(state_dict, add_prefix)
5 changes: 5 additions & 0 deletions src/otx/algo/segmentation/litehrnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch.onnx import OperatorExportTypes

from otx.algo.utils.mmconfig import read_mmconfig
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.model.entity.segmentation import MMSegCompatibleModel


Expand All @@ -24,3 +25,7 @@ def _configure_export_parameters(self) -> None:
self.export_params["onnx_export_configuration"] = {
"operator_export_type": OperatorExportTypes.ONNX_ATEN_FALLBACK,
}

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_seg_lite_hrnet_ckpt(state_dict, add_prefix)
5 changes: 5 additions & 0 deletions src/otx/algo/segmentation/segnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Literal

from otx.algo.utils.mmconfig import read_mmconfig
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.model.entity.segmentation import MMSegCompatibleModel


Expand All @@ -16,3 +17,7 @@ def __init__(self, num_classes: int, variant: Literal["b", "s", "t"]) -> None:
model_name = f"segnext_{variant}"
config = read_mmconfig(model_name=model_name)
super().__init__(num_classes=num_classes, config=config)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_seg_segnext_ckpt(state_dict, add_prefix)
Loading

0 comments on commit 213e8e4

Please sign in to comment.