Skip to content

Commit

Permalink
Merge branch 'v2' into infra/v2/tests-on-aws
Browse files Browse the repository at this point in the history
  • Loading branch information
yunchu authored Jan 30, 2024
2 parents 470216f + 5c27276 commit 7c09060
Show file tree
Hide file tree
Showing 18 changed files with 899 additions and 168 deletions.
6 changes: 6 additions & 0 deletions for_developers/cli_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,12 @@ otx train --config <config-file-path> --print_config > config.yaml

## otx {subcommand}

Use Auto-Configuration

```console
otx train --data_root <dataset-root> --task <TASK>
```

Use Configuration file

```console
Expand Down
14 changes: 10 additions & 4 deletions for_developers/setup_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,16 @@ Please see [requirements-lock.txt](requirements-lock.txt). This is what I got af

## Launch training with demo template

- Launch detection task ATSS-R50-FPN template
- Auto-Configuration from dataset & task (Default Model: ATSS-MobilenetV2)

```console
otx train --config src/otx/recipe/detection/atss_r50_fpn.yaml --data_root tests/assets/car_tree_bug --model.num_classes=3 --max_epochs=50 --check_val_every_n_epoch=10 --engine.device gpu --engine.work_dir ./otx-workspace
otx train --data_root tests/assets/car_tree_bug --model.num_classes 3 --engine.device gpu --engine.work_dir ./otx-workspace
```

- Launch detection task ATSS-MobilenetV2 template

```console
otx train --config src/otx/recipe/detection/atss_mobilenetv2.yaml --data_root tests/assets/car_tree_bug --model.num_classes 3 --max_epochs 50 --check_val_every_n_epoch 10 --engine.device gpu --engine.work_dir ./otx-workspace
```

- Change subset names, e.g., "train" -> "train_16" (for training)
Expand All @@ -56,7 +62,7 @@ Please see [requirements-lock.txt](requirements-lock.txt). This is what I got af
- Do train with the existing model checkpoint for resume

```console
otx train ... --checkpoint <checkpoint-path>
otx train ... --engine.checkpoint <checkpoint-path>
```

- Do experiment with deterministic operations and the fixed seed
Expand All @@ -68,7 +74,7 @@ Please see [requirements-lock.txt](requirements-lock.txt). This is what I got af
- Do test with the existing model checkpoint

```console
otx test ... checkpoint=<checkpoint-path>
otx test ... --checkpoint=<checkpoint-path>
```

`--deterministic True` might affect to the model performance. Please see [this link](https://lightning.ai/docs/pytorch/stable/common/trainer.html#deterministic). Therefore, it is not recommended to turn on this option for the model performance comparison.
148 changes: 86 additions & 62 deletions src/otx/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,18 @@

from __future__ import annotations

import sys
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional

import yaml
from jsonargparse import ActionConfigFile, ArgumentParser, Namespace, namespace_to_dict
from rich.console import Console

from otx import OTX_LOGO, __version__
from otx.cli.utils import get_otx_root_path
from otx.cli.utils.help_formatter import CustomHelpFormatter
from otx.cli.utils.jsonargparse import get_short_docstring, patch_update_configs
from otx.core.utils.imports import get_otx_root_path

if TYPE_CHECKING:
from jsonargparse._actions import _ActionSubCommands
Expand Down Expand Up @@ -67,11 +68,15 @@ def init_parser(self) -> ArgumentParser:
)
return parser

def subcommand_parser(self, **kwargs) -> ArgumentParser:
"""Returns an ArgumentParser object for parsing command line arguments specific to a subcommand.
@staticmethod
def engine_subcommand_parser(**kwargs) -> ArgumentParser:
"""Creates an ArgumentParser object for the engine subcommand.
Args:
**kwargs: Additional keyword arguments to be passed to the ArgumentParser constructor.
Returns:
ArgumentParser: An ArgumentParser object configured with the specified arguments.
ArgumentParser: The created ArgumentParser object.
"""
parser = ArgumentParser(
formatter_class=CustomHelpFormatter,
Expand All @@ -93,7 +98,7 @@ def subcommand_parser(self, **kwargs) -> ArgumentParser:
)
parser.add_argument(
"--data_root",
type=str,
type=Optional[str],
help="Path to dataset root.",
)
parser.add_argument(
Expand All @@ -111,6 +116,52 @@ def subcommand_parser(self, **kwargs) -> ArgumentParser:
type=str,
help="The metric to monitor the model performance during training callbacks.",
)
engine_skip = {"model", "datamodule", "optimizer", "scheduler"}
parser.add_class_arguments(
Engine,
"engine",
fail_untyped=False,
sub_configs=True,
instantiate=False,
skip=engine_skip,
)
# Model Settings
from otx.core.model.entity.base import OTXModel

model_kwargs: dict[str, Any] = {"fail_untyped": False}

parser.add_subclass_arguments(
OTXModel,
"model",
required=False,
**model_kwargs,
)
# Datamodule Settings
from otx.core.data.module import OTXDataModule

parser.add_class_arguments(
OTXDataModule,
"data",
fail_untyped=False,
sub_configs=True,
)
# Optimizer & Scheduler Settings
from lightning.pytorch.cli import LRSchedulerTypeTuple
from torch.optim import Optimizer

optim_kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}}
scheduler_kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}}
parser.add_subclass_arguments(
baseclass=(Optimizer,),
nested_key="optimizer",
**optim_kwargs,
)
parser.add_subclass_arguments(
baseclass=LRSchedulerTypeTuple,
nested_key="scheduler",
**scheduler_kwargs,
)

return parser

@staticmethod
Expand Down Expand Up @@ -148,84 +199,57 @@ def add_subcommands(self) -> None:
# If environment is not configured to use Engine, do not add a subcommand for Engine.
return
for subcommand in self.engine_subcommands():
sub_parser = self.subcommand_parser()
engine_skip = {"model", "datamodule", "optimizer", "scheduler"}
sub_parser.add_class_arguments(
Engine,
"engine",
fail_untyped=False,
sub_configs=True,
instantiate=False,
skip=engine_skip,
)
sub_parser.link_arguments("data_root", "engine.data_root")
parser_kwargs = self._set_default_config_from_auto_configurator()
sub_parser = self.engine_subcommand_parser(**parser_kwargs)

# Model Settings
from otx.core.model.entity.base import OTXModel

model_kwargs: dict[str, Any] = {"fail_untyped": False}

sub_parser.add_subclass_arguments(
OTXModel,
"model",
required=False,
**model_kwargs,
)
# Datamodule Settings
from otx.core.data.module import OTXDataModule

sub_parser.add_class_arguments(
OTXDataModule,
"data",
fail_untyped=False,
sub_configs=True,
)
sub_parser.link_arguments("data_root", "engine.data_root")
sub_parser.link_arguments("data_root", "data.config.data_root")

# Optimizer & Scheduler Settings
from lightning.pytorch.cli import LRSchedulerTypeTuple
from torch.optim import Optimizer

optim_kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}}
scheduler_kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}}
sub_parser.add_subclass_arguments(
baseclass=(Optimizer,),
nested_key="optimizer",
**optim_kwargs,
)
sub_parser.add_subclass_arguments(
baseclass=LRSchedulerTypeTuple,
nested_key="scheduler",
**scheduler_kwargs,
)

skip: set[str | int] = set(self.engine_subcommands()[subcommand])
fn = getattr(Engine, subcommand)
description = get_short_docstring(fn)

added_arguments = sub_parser.add_method_arguments(
Engine,
subcommand,
skip=skip,
skip=set(self.engine_subcommands()[subcommand]),
fail_untyped=False,
)

if "logger" in added_arguments:
sub_parser.link_arguments("engine.work_dir", "logger.init_args.save_dir")
if "callbacks" in added_arguments:
sub_parser.link_arguments("callback_monitor", "callbacks.init_args.monitor")
sub_parser.link_arguments("engine.work_dir", "callbacks.init_args.dirpath")

# Load default subcommand config file
default_config_file = get_otx_root_path() / "recipe" / "_base_" / f"{subcommand}.yaml"
if default_config_file.exists():
with Path(default_config_file).open() as f:
default_config = yaml.safe_load(f)
sub_parser.set_defaults(**default_config)

if "logger" in added_arguments:
sub_parser.link_arguments("engine.work_dir", "logger.init_args.save_dir")
if "callbacks" in added_arguments:
sub_parser.link_arguments("callback_monitor", "callbacks.init_args.monitor")
sub_parser.link_arguments("engine.work_dir", "callbacks.init_args.dirpath")

self._subcommand_method_arguments[subcommand] = added_arguments
self._subcommand_parsers[subcommand] = sub_parser
parser_subcommands.add_subcommand(subcommand, sub_parser, help=description)

def _set_default_config_from_auto_configurator(self) -> dict:
parser_kwargs = {}
data_root = None
task = None
if "--data_root" in sys.argv:
data_root = sys.argv[sys.argv.index("--data_root") + 1]
if "--task" in sys.argv:
task = sys.argv[sys.argv.index("--task") + 1]
enable_auto_config = data_root is not None and "--config" not in sys.argv
if enable_auto_config:
from otx.core.types.task import OTXTaskType
from otx.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK, AutoConfigurator

auto_configurator = AutoConfigurator(data_root=data_root, task=OTXTaskType(task))
config_file_path = DEFAULT_CONFIG_PER_TASK[auto_configurator.task]
parser_kwargs["default_config_files"] = [config_file_path]
return parser_kwargs

def _set_extension_subcommands_parser(self, parser_subcommands: _ActionSubCommands) -> None:
from otx.cli.install import add_install_parser

Expand Down
21 changes: 0 additions & 21 deletions src/otx/cli/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,3 @@
# SPDX-License-Identifier: Apache-2.0

"""CLI Utils."""

import importlib
import inspect
from pathlib import Path


def get_otx_root_path() -> Path:
"""Return the root path of the otx module.
Returns:
str: The root path of the otx module.
Raises:
ModuleNotFoundError: If the otx module is not found.
"""
otx_module = importlib.import_module("otx")
if otx_module:
file_path = inspect.getfile(otx_module)
return Path(file_path).parent
msg = "Cannot found otx."
raise ModuleNotFoundError(msg)
Loading

0 comments on commit 7c09060

Please sign in to comment.