Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add auto-config num_classes in CLI side #2861

Merged
merged 5 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions src/otx/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional
from warnings import warn

import yaml
from jsonargparse import ActionConfigFile, ArgumentParser, Namespace, namespace_to_dict
Expand Down Expand Up @@ -117,6 +118,14 @@ def engine_subcommand_parser(**kwargs) -> ArgumentParser:
type=str,
help="The metric to monitor the model performance during training callbacks.",
)
parser.add_argument(
"--disable-infer-num-classes",
help="OTX automatically infers num_classes from the given dataset "
"and applies it to the model initialization."
"Consequently, there might be a mismatch with the provided model configuration during runtime. "
"Setting this option to true will disable this behavior.",
action="store_true",
)
engine_skip = {"model", "datamodule", "optimizer", "scheduler"}
parser.add_class_arguments(
Engine,
Expand Down Expand Up @@ -279,9 +288,11 @@ def instantiate_classes(self) -> None:
If it is, it instantiates the necessary classes such as config, datamodule, model, and engine.
"""
if self.subcommand in self.engine_subcommands():
# For num_classes update, Model is instantiated separately.
model_config = self.config[self.subcommand].pop("model")
self.config_init = self.parser.instantiate_classes(self.config)
self.datamodule = self.get_config_value(self.config_init, "data")
self.model, optimizer, scheduler = self.instantiate_model()
self.model, optimizer, scheduler = self.instantiate_model(model_config=model_config)

engine_kwargs = self.get_config_value(self.config_init, "engine")
self.engine = Engine(
Expand All @@ -292,16 +303,38 @@ def instantiate_classes(self) -> None:
**engine_kwargs,
)

def instantiate_model(self) -> tuple:
def instantiate_model(self, model_config: Namespace) -> tuple:
"""Instantiate the model based on the subcommand.

This method checks if the subcommand is one of the engine subcommands.
If it is, it instantiates the model.

Args:
model_config (Namespace): The model configuration.

Returns:
tuple: The model and optimizer and scheduler.
"""
model = self.get_config_value(self.config_init, "model")
from otx.core.model.entity.base import OTXModel
from otx.engine.utils.auto_configurator import get_num_classes_from_meta_info

# Update num_classes
if not self.get_config_value(self.config_init, "disable_infer_num_classes", False):
num_classes = get_num_classes_from_meta_info(task=self.datamodule.task, meta_info=self.datamodule.meta_info)
if num_classes != model_config.init_args.num_classes:
warning_msg = (
f"The `num_classes` in dataset is {num_classes} "
f"but, the `num_classes` of model is {model_config.init_args.num_classes}. "
f"So, Update `model.num_classes` to {num_classes}."
)
warn(warning_msg, stacklevel=0)
model_config.init_args.num_classes = num_classes

# Parses the OTXModel separately to update num_classes.
model_parser = ArgumentParser()
model_parser.add_subclass_arguments(OTXModel, "model", required=False, fail_untyped=False)
model = model_parser.instantiate_classes(Namespace(model=model_config)).get("model")

optimizer_kwargs = namespace_to_dict(self.get_config_value(self.config_init, "optimizer", Namespace()))
scheduler_kwargs = namespace_to_dict(self.get_config_value(self.config_init, "scheduler", Namespace()))
from otx.core.utils.instantiators import partial_instantiate_class
Expand Down
2 changes: 1 addition & 1 deletion src/otx/cli/utils/jsonargparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def get_configuration(config_path: str | Path) -> dict:
logger.info(f"{config_path} is loaded.")

# Remove unnecessary cli arguments for API usage
cli_args = ["verbose", "data_root", "task", "seed", "callback_monitor", "resume"]
cli_args = ["verbose", "data_root", "task", "seed", "callback_monitor", "resume", "disable_infer_num_classes"]
logger.warning(f"The corresponding keys in config are not used.: {cli_args}")
for arg in cli_args:
config.pop(arg, None)
Expand Down
21 changes: 0 additions & 21 deletions tests/integration/cli/test_auto_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,6 @@
from otx.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK


# [TODO]: Please Remove this with auto num_classes update feature in CLI
@pytest.fixture()
def fxt_cli_override_command_per_task() -> dict:
return {
"multi_class_cls": ["--model.num_classes", "2"],
"multi_label_cls": ["--model.num_classes", "2"],
"detection": ["--model.num_classes", "3"],
"instance_segmentation": ["--model.num_classes", "3"],
"semantic_segmentation": ["--model.num_classes", "2"],
"action_classification": ["--model.num_classes", "2"],
"action_detection": [
"--model.num_classes",
"5",
"--model.topk",
"3",
],
"visual_prompting": [],
"zero_shot_visual_prompting": ["--max_epochs", "1"],
}


@pytest.mark.parametrize("task", [task.value.lower() for task in DEFAULT_CONFIG_PER_TASK])
def test_otx_cli_auto_configuration(
task: str,
Expand Down
1 change: 1 addition & 0 deletions tests/integration/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ def test_otx_ov_test(recipe: str, tmp_path: Path, fxt_target_dataset_per_task: d
str(tmp_path_test / "outputs"),
"--engine.device",
"cpu",
"--disable-infer-num-classes",
]

with patch("sys.argv", command_cfg):
Expand Down
22 changes: 11 additions & 11 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,26 @@ def fxt_target_dataset_per_task() -> dict:
@pytest.fixture()
def fxt_cli_override_command_per_task() -> dict:
return {
"multi_class_cls": ["--model.num_classes", "2"],
"multi_label_cls": ["--model.num_classes", "2"],
"multi_class_cls": [],
"multi_label_cls": [],
"h_label_cls": [
"--model.num_classes",
"7",
"--model.num_multiclass_heads",
"2",
"--model.num_multilabel_classes",
"3",
],
"detection": ["--model.num_classes", "3"],
"instance_segmentation": ["--model.num_classes", "3"],
"semantic_segmentation": ["--model.num_classes", "2"],
"action_classification": ["--model.num_classes", "2"],
"detection": [],
"instance_segmentation": [],
"semantic_segmentation": [],
"action_classification": [],
"action_detection": [
"--model.num_classes",
"5",
"--model.topk",
"3",
],
"visual_prompting": [],
"zero_shot_visual_prompting": [],
"zero_shot_visual_prompting": [
"--max_epochs",
"1",
"--disable-infer-num-classes",
],
}
Loading