Skip to content

Commit

Permalink
update interface
Browse files Browse the repository at this point in the history
  • Loading branch information
eunwoosh committed Aug 5, 2024
1 parent bf0c736 commit 8113046
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 20 deletions.
1 change: 1 addition & 0 deletions src/otx/algo/detection/atss.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
torch_compile=torch_compile,
tile_config=tile_config,
)
breakpoint()
self.tile_image_size = tile_image_size

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,6 @@ def bbox_loss(self, x: tuple[Tensor], sampling_results: list[SamplingResult], ba

class CustomConvFCBBoxHead(Shared2FCBBoxHead, ClassIncrementalMixin):
"""CustomConvFCBBoxHead class for OTX."""
# checked

def loss_and_target(
self,
Expand Down
12 changes: 5 additions & 7 deletions src/otx/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,18 +331,16 @@ def instantiate_classes(self, instantiate_engine: bool = True) -> None:
# For num_classes update, Model and Metric are instantiated separately.
model_config = self.config[self.subcommand].pop("model")

input_size = self.config["train"]["engine"].get("input_size")
if input_size is not None:
if isinstance(input_size, int):
input_size = (input_size, input_size)
self.config["train"]["data"]["input_size"] = input_size
model_config["init_args"]["input_size"] = tuple(model_config["init_args"]["input_size"][:-2]) + input_size

# Instantiate the things that don't need to special handling
self.config_init = self.parser.instantiate_classes(self.config)
self.workspace = self.get_config_value(self.config_init, "workspace")
self.datamodule = self.get_config_value(self.config_init, "data")

if (input_size := self.datamodule.input_size) is not None:
if isinstance(input_size, int):
input_size = (input_size, input_size)
model_config["init_args"]["input_size"] = tuple(model_config["init_args"]["input_size"][:-2]) + input_size

# Instantiate the model and needed components
self.model = self.instantiate_model(model_config=model_config)

Expand Down
5 changes: 5 additions & 0 deletions src/otx/core/data/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,22 @@ def __init__(
auto_num_workers: bool = False,
device: DeviceType = DeviceType.auto,
input_size: int | tuple[int, int] | None = None,
adaptive_input_size: bool = False,
) -> None:
"""Constructor."""
super().__init__()
self.task = task
self.data_format = data_format
self.data_root = data_root

if adaptive_input_size:
print("adaptive_input_size works")

if input_size is not None:
for subset_cfg in [train_subset, val_subset, test_subset, unlabeled_subset]:
if subset_cfg.input_size is None:
subset_cfg.input_size = input_size
self.input_size = input_size

self.train_subset = train_subset
self.val_subset = val_subset
Expand Down
11 changes: 1 addition & 10 deletions src/otx/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ def __init__(
checkpoint: PathLike | None = None,
device: DeviceType = DeviceType.auto,
num_devices: int = 1,
input_size: Sequence[int] | int | None = None,
**kwargs,
):
"""Initializes the OTX Engine.
Expand All @@ -147,17 +146,8 @@ def __init__(
data_root=data_root,
task=datamodule.task if datamodule is not None else task,
model_name=None if isinstance(model, OTXModel) else model,
input_size=input_size,
)

if input_size is not None:
if isinstance(datamodule, OTXDataModule) and datamodule.input_size != input_size:
msg = "Data module is already initialized. Input size will be ignored to data module."
logging.warning(msg)
if isinstance(model, OTXModel) and model.input_size != input_size:
msg = "Model is already initialized. Input size will be ignored to model."
logging.warning(msg)

self._datamodule: OTXDataModule | None = (
datamodule if datamodule is not None else self._auto_configurator.get_datamodule()
)
Expand All @@ -169,6 +159,7 @@ def __init__(
if isinstance(model, OTXModel)
else self._auto_configurator.get_model(
label_info=self._datamodule.label_info if self._datamodule is not None else None,
input_size=self._datamodule.input_size,
)
)

Expand Down
9 changes: 7 additions & 2 deletions src/otx/engine/utils/auto_configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
],
"common_semantic_segmentation_with_subset_dirs": [OTXTaskType.SEMANTIC_SEGMENTATION],
"kinetics": [OTXTaskType.ACTION_CLASSIFICATION],
"mvtec_classification": [OTXTaskType.ANOMALY_CLASSIFICATION, OTXTaskType.ANOMALY_DETECTION, OTXTaskType.ANOMALY_SEGMENTATION],
"mvtec": [OTXTaskType.ANOMALY_CLASSIFICATION, OTXTaskType.ANOMALY_DETECTION, OTXTaskType.ANOMALY_SEGMENTATION],
}

OVMODEL_PER_TASK = {
Expand Down Expand Up @@ -245,7 +245,7 @@ def get_datamodule(self) -> OTXDataModule | None:
**data_config,
)

def get_model(self, model_name: str | None = None, label_info: LabelInfoTypes | None = None) -> OTXModel:
def get_model(self, model_name: str | None = None, label_info: LabelInfoTypes | None = None, input_size: Sequence[int] | None = None) -> OTXModel:
"""Retrieves the OTXModel instance based on the provided model name and meta information.
Args:
Expand Down Expand Up @@ -278,6 +278,11 @@ def get_model(self, model_name: str | None = None, label_info: LabelInfoTypes |

model_config = deepcopy(self.config["model"])

if input_size is not None:
if isinstance(input_size, int):
input_size = (input_size, input_size)
model_config["init_args"]["input_size"] = tuple(model_config["init_args"]["input_size"][:-2]) + input_size

model_cls = get_model_cls_from_config(Namespace(model_config))

if should_pass_label_info(model_cls):
Expand Down

0 comments on commit 8113046

Please sign in to comment.