From 8113046a388450c6c0dc87e26fa572623a465749 Mon Sep 17 00:00:00 2001 From: "Shin, Eunwoo" Date: Mon, 5 Aug 2024 17:09:02 +0900 Subject: [PATCH] update interface --- src/otx/algo/detection/atss.py | 1 + .../instance_segmentation/heads/custom_roi_head.py | 1 - src/otx/cli/cli.py | 12 +++++------- src/otx/core/data/module.py | 5 +++++ src/otx/engine/engine.py | 11 +---------- src/otx/engine/utils/auto_configurator.py | 9 +++++++-- 6 files changed, 19 insertions(+), 20 deletions(-) diff --git a/src/otx/algo/detection/atss.py b/src/otx/algo/detection/atss.py index e1d5c5842eb..02d0a5d455d 100644 --- a/src/otx/algo/detection/atss.py +++ b/src/otx/algo/detection/atss.py @@ -56,6 +56,7 @@ def __init__( torch_compile=torch_compile, tile_config=tile_config, ) + breakpoint() self.tile_image_size = tile_image_size @property diff --git a/src/otx/algo/instance_segmentation/heads/custom_roi_head.py b/src/otx/algo/instance_segmentation/heads/custom_roi_head.py index 4536956b873..360027b1376 100644 --- a/src/otx/algo/instance_segmentation/heads/custom_roi_head.py +++ b/src/otx/algo/instance_segmentation/heads/custom_roi_head.py @@ -548,7 +548,6 @@ def bbox_loss(self, x: tuple[Tensor], sampling_results: list[SamplingResult], ba class CustomConvFCBBoxHead(Shared2FCBBoxHead, ClassIncrementalMixin): """CustomConvFCBBoxHead class for OTX.""" - # checked def loss_and_target( self, diff --git a/src/otx/cli/cli.py b/src/otx/cli/cli.py index 6759d97adbe..fcec67ec2d7 100644 --- a/src/otx/cli/cli.py +++ b/src/otx/cli/cli.py @@ -331,18 +331,16 @@ def instantiate_classes(self, instantiate_engine: bool = True) -> None: # For num_classes update, Model and Metric are instantiated separately. model_config = self.config[self.subcommand].pop("model") - input_size = self.config["train"]["engine"].get("input_size") - if input_size is not None: - if isinstance(input_size, int): - input_size = (input_size, input_size) - self.config["train"]["data"]["input_size"] = input_size - model_config["init_args"]["input_size"] = tuple(model_config["init_args"]["input_size"][:-2]) + input_size - # Instantiate the things that don't need to special handling self.config_init = self.parser.instantiate_classes(self.config) self.workspace = self.get_config_value(self.config_init, "workspace") self.datamodule = self.get_config_value(self.config_init, "data") + if (input_size := self.datamodule.input_size) is not None: + if isinstance(input_size, int): + input_size = (input_size, input_size) + model_config["init_args"]["input_size"] = tuple(model_config["init_args"]["input_size"][:-2]) + input_size + # Instantiate the model and needed components self.model = self.instantiate_model(model_config=model_config) diff --git a/src/otx/core/data/module.py b/src/otx/core/data/module.py index a1a5cdced8a..d371bb5320d 100644 --- a/src/otx/core/data/module.py +++ b/src/otx/core/data/module.py @@ -63,6 +63,7 @@ def __init__( auto_num_workers: bool = False, device: DeviceType = DeviceType.auto, input_size: int | tuple[int, int] | None = None, + adaptive_input_size: bool = False, ) -> None: """Constructor.""" super().__init__() @@ -70,10 +71,14 @@ def __init__( self.data_format = data_format self.data_root = data_root + if adaptive_input_size: + print("adaptive_input_size works") + if input_size is not None: for subset_cfg in [train_subset, val_subset, test_subset, unlabeled_subset]: if subset_cfg.input_size is None: subset_cfg.input_size = input_size + self.input_size = input_size self.train_subset = train_subset self.val_subset = val_subset diff --git a/src/otx/engine/engine.py b/src/otx/engine/engine.py index edd0d6c063c..ee5ff4dce35 100644 --- a/src/otx/engine/engine.py +++ b/src/otx/engine/engine.py @@ -122,7 +122,6 @@ def __init__( checkpoint: PathLike | None = None, device: DeviceType = DeviceType.auto, num_devices: int = 1, - input_size: Sequence[int] | int | None = None, **kwargs, ): """Initializes the OTX Engine. @@ -147,17 +146,8 @@ def __init__( data_root=data_root, task=datamodule.task if datamodule is not None else task, model_name=None if isinstance(model, OTXModel) else model, - input_size=input_size, ) - if input_size is not None: - if isinstance(datamodule, OTXDataModule) and datamodule.input_size != input_size: - msg = "Data module is already initialized. Input size will be ignored to data module." - logging.warning(msg) - if isinstance(model, OTXModel) and model.input_size != input_size: - msg = "Model is already initialized. Input size will be ignored to model." - logging.warning(msg) - self._datamodule: OTXDataModule | None = ( datamodule if datamodule is not None else self._auto_configurator.get_datamodule() ) @@ -169,6 +159,7 @@ def __init__( if isinstance(model, OTXModel) else self._auto_configurator.get_model( label_info=self._datamodule.label_info if self._datamodule is not None else None, + input_size=self._datamodule.input_size, ) ) diff --git a/src/otx/engine/utils/auto_configurator.py b/src/otx/engine/utils/auto_configurator.py index 26992720134..65e8f8bf2e7 100644 --- a/src/otx/engine/utils/auto_configurator.py +++ b/src/otx/engine/utils/auto_configurator.py @@ -65,7 +65,7 @@ ], "common_semantic_segmentation_with_subset_dirs": [OTXTaskType.SEMANTIC_SEGMENTATION], "kinetics": [OTXTaskType.ACTION_CLASSIFICATION], - "mvtec_classification": [OTXTaskType.ANOMALY_CLASSIFICATION, OTXTaskType.ANOMALY_DETECTION, OTXTaskType.ANOMALY_SEGMENTATION], + "mvtec": [OTXTaskType.ANOMALY_CLASSIFICATION, OTXTaskType.ANOMALY_DETECTION, OTXTaskType.ANOMALY_SEGMENTATION], } OVMODEL_PER_TASK = { @@ -245,7 +245,7 @@ def get_datamodule(self) -> OTXDataModule | None: **data_config, ) - def get_model(self, model_name: str | None = None, label_info: LabelInfoTypes | None = None) -> OTXModel: + def get_model(self, model_name: str | None = None, label_info: LabelInfoTypes | None = None, input_size: Sequence[int] | None = None) -> OTXModel: """Retrieves the OTXModel instance based on the provided model name and meta information. Args: @@ -278,6 +278,11 @@ def get_model(self, model_name: str | None = None, label_info: LabelInfoTypes | model_config = deepcopy(self.config["model"]) + if input_size is not None: + if isinstance(input_size, int): + input_size = (input_size, input_size) + model_config["init_args"]["input_size"] = tuple(model_config["init_args"]["input_size"][:-2]) + input_size + model_cls = get_model_cls_from_config(Namespace(model_config)) if should_pass_label_info(model_cls):