Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added segmentation configuration for export to align new interface #2808

Merged
merged 10 commits into from
Jan 17, 2024
9 changes: 9 additions & 0 deletions src/otx/algo/segmentation/dino_v2_seg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,12 @@ def __init__(self, num_classes: int) -> None:
model_name = "dino_v2_seg"
config = read_mmconfig(model_name=model_name)
super().__init__(num_classes=num_classes, config=config)

def export(self, *args) -> None:
"""Export method for DinoV2Seg.

Model doesn't support export for now due to unsupported operations from xformers.
This method will raise an error.
"""
msg = "{model_name} cannot be exported. It is not supported."
raise RuntimeError(msg)
9 changes: 9 additions & 0 deletions src/otx/algo/segmentation/litehrnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from typing import Literal

from torch.onnx import OperatorExportTypes

from otx.algo.utils.mmconfig import read_mmconfig
from otx.core.model.entity.segmentation import MMSegCompatibleModel

Expand All @@ -16,3 +18,10 @@ def __init__(self, num_classes: int, variant: Literal["18", "s", "x"]) -> None:
model_name = f"litehrnet_{variant}"
config = read_mmconfig(model_name=model_name)
super().__init__(num_classes=num_classes, config=config)

def _configure_export_parameters(self) -> None:
super()._configure_export_parameters()
self.export_params["via_onnx"] = True
self.export_params["onnx_export_configuration"] = {
"operator_export_type": OperatorExportTypes.ONNX_ATEN_FALLBACK,
}
2 changes: 1 addition & 1 deletion src/otx/algo/segmentation/mmconfigs/segnext_b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ decode_head:
requires_grad: true
type: GN
num_classes: 150
type: LightHamHead
type: CustomLightHamHead
pretrained: null
test_cfg:
mode: whole
Expand Down
2 changes: 1 addition & 1 deletion src/otx/algo/segmentation/mmconfigs/segnext_s.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ decode_head:
requires_grad: true
type: GN
num_classes: 4
type: LightHamHead
type: CustomLightHamHead
pretrained: null
test_cfg:
mode: whole
Expand Down
2 changes: 1 addition & 1 deletion src/otx/algo/segmentation/mmconfigs/segnext_t.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ decode_head:
requires_grad: true
type: GN
num_classes: 150
type: LightHamHead
type: CustomLightHamHead
pretrained: null
test_cfg:
mode: whole
Expand Down
4 changes: 4 additions & 0 deletions src/otx/algo/segmentation/segnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@ def __init__(self, num_classes: int, variant: Literal["b", "s", "t"]) -> None:
model_name = f"segnext_{variant}"
config = read_mmconfig(model_name=model_name)
super().__init__(num_classes=num_classes, config=config)

def _configure_export_parameters(self) -> None:
super()._configure_export_parameters()
self.export_params["via_onnx"] = True
7 changes: 4 additions & 3 deletions src/otx/core/model/entity/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(self, num_classes: int, config: DictConfig) -> None:
self.config = config
self.export_params = _get_export_params_from_cls_mmconfig(config)
self.load_from = config.pop("load_from", None)
self.image_size = (224, 224)
super().__init__(num_classes=num_classes)

def _create_model(self) -> nn.Module:
Expand Down Expand Up @@ -164,7 +165,7 @@ def _configure_export_parameters(self) -> None:
self.export_params["pad_value"] = 0
self.export_params["swap_rgb"] = False
self.export_params["via_onnx"] = False
self.export_params["input_size"] = (1, 3, 224, 224)
self.export_params["input_size"] = (1, 3, *self.image_size)
self.export_params["onnx_export_configuration"] = None

def _create_exporter(
Expand Down Expand Up @@ -284,7 +285,7 @@ def _configure_export_parameters(self) -> None:
self.export_params["pad_value"] = 0
self.export_params["swap_rgb"] = False
self.export_params["via_onnx"] = False
self.export_params["input_size"] = (1, 3, 224, 224)
self.export_params["input_size"] = (1, 3, *self.image_size)
self.export_params["onnx_export_configuration"] = None

def _create_exporter(
Expand Down Expand Up @@ -404,7 +405,7 @@ def _configure_export_parameters(self) -> None:
self.export_params["pad_value"] = 0
self.export_params["swap_rgb"] = False
self.export_params["via_onnx"] = False
self.export_params["input_size"] = (1, 3, 224, 224)
self.export_params["input_size"] = (1, 3, *self.image_size)
self.export_params["onnx_export_configuration"] = None

def _create_exporter(
Expand Down
5 changes: 3 additions & 2 deletions src/otx/core/model/entity/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def _generate_model_metadata(
metadata[("model_info", "task_type")] = "segmentation"
metadata[("model_info", "return_soft_prediction")] = str(True)
metadata[("model_info", "soft_threshold")] = str(0.5)
metadata[("model_info", "blur_strength")] = str(1)
metadata[("model_info", "blur_strength")] = str(-1)

return metadata

Expand All @@ -60,6 +60,7 @@ def __init__(self, num_classes: int, config: DictConfig) -> None:
self.config = config
self.export_params = _get_export_params_from_seg_mmconfig(config)
self.load_from = self.config.pop("load_from", None)
self.image_size = (544, 544)
super().__init__(num_classes=num_classes)

def _create_model(self) -> nn.Module:
Expand Down Expand Up @@ -151,7 +152,7 @@ def _configure_export_parameters(self) -> None:
self.export_params["pad_value"] = 0
self.export_params["swap_rgb"] = False
self.export_params["via_onnx"] = False
self.export_params["input_size"] = (1, 3, 512, 512)
self.export_params["input_size"] = (1, 3, *self.image_size)
self.export_params["onnx_export_configuration"] = None

def _create_exporter(
Expand Down
1 change: 1 addition & 0 deletions src/otx/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ def export(self, output_dir: Path, cfg: ExportConfig) -> None:
output_dir (Path): Directory path to save exported binary files.
"""
if self.checkpoint is not None:
self.model.eval()
lit_module = self._build_lightning_module(
model=self.model,
optimizer=self.optimizer,
Expand Down