From 98b7326380c3f743623e89c8034940ccadfe178a Mon Sep 17 00:00:00 2001 From: harimkang Date: Thu, 25 Jan 2024 09:48:39 +0900 Subject: [PATCH 1/5] Move seed into CLI side --- src/otx/cli/cli.py | 19 +++++++++++++++++++ src/otx/engine/engine.py | 8 +------- src/otx/recipe/_base_/train.yaml | 1 - 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/src/otx/cli/cli.py b/src/otx/cli/cli.py index 7d3f7cecf29..06330f6a674 100644 --- a/src/otx/cli/cli.py +++ b/src/otx/cli/cli.py @@ -101,6 +101,11 @@ def subcommand_parser(self, **kwargs) -> ArgumentParser: type=str, help="Task Type.", ) + parser.add_argument( + "--seed", + type=int, + help="Sets seed for pseudo-random number generators in: pytorch, numpy, python.random.", + ) parser.add_argument( "--callback_monitor", type=str, @@ -307,6 +312,19 @@ def save_config(self) -> None: skip_check=True, ) + def set_seed(self) -> None: + """Set the random seed for reproducibility. + + This method retrieves the seed value from the argparser and uses it to set the random seed. + If a seed value is provided, it will be used to set the random seed using the + `seed_everything` function from the `lightning` module. + """ + seed = self.get_config_value(self.config, "seed", None) + if seed is not None: + from lightning import seed_everything + + seed_everything(seed, workers=True) + def run(self) -> None: """Executes the specified subcommand. @@ -319,6 +337,7 @@ def run(self) -> None: otx_install(**self.config["install"]) elif self.subcommand in self.engine_subcommands(): + self.set_seed() self.instantiate_classes() fn_kwargs = self._prepare_subcommand_kwargs(self.subcommand) fn = getattr(self.engine, self.subcommand) diff --git a/src/otx/engine/engine.py b/src/otx/engine/engine.py index 2a2c7e17f29..3ca7958ce22 100644 --- a/src/otx/engine/engine.py +++ b/src/otx/engine/engine.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Iterable import torch -from lightning import Trainer, seed_everything +from lightning import Trainer from otx.core.config.device import DeviceConfig from otx.core.data.module import OTXDataModule @@ -126,7 +126,6 @@ def __init__( def train( self, max_epochs: int = 10, - seed: int | None = None, deterministic: bool = False, precision: _PRECISION_INPUT | None = "32", val_check_interval: int | float | None = None, @@ -139,7 +138,6 @@ def train( Args: max_epochs (int | None, optional): The maximum number of epochs. Defaults to None. - seed (int | None, optional): The random seed. Defaults to None. deterministic (bool | None, optional): Whether to enable deterministic behavior. Defaults to False. precision (_PRECISION_INPUT | None, optional): The precision of the model. Defaults to 32. val_check_interval (int | float | None, optional): The validation check interval. Defaults to None. @@ -154,7 +152,6 @@ def train( Example: >>> engine.train( ... max_epochs=3, - ... seed=1234, ... deterministic=False, ... precision="32", ... ) @@ -188,9 +185,6 @@ def train( ) lit_module.meta_info = self.datamodule.meta_info - if seed is not None: - seed_everything(seed, workers=True) - self._build_trainer( logger=logger, callbacks=callbacks, diff --git a/src/otx/recipe/_base_/train.yaml b/src/otx/recipe/_base_/train.yaml index 64ab9e551dc..df0ede1ae34 100644 --- a/src/otx/recipe/_base_/train.yaml +++ b/src/otx/recipe/_base_/train.yaml @@ -50,7 +50,6 @@ logger: default_hp_metric: true prefix: "" deterministic: false -seed: null precision: 16 check_val_every_n_epoch: 1 gradient_clip_val: null From 862a992111176440e3544d6a717360670508b3bc Mon Sep 17 00:00:00 2001 From: harimkang Date: Thu, 25 Jan 2024 10:04:12 +0900 Subject: [PATCH 2/5] Fix unit test cli --- tests/unit/cli/test_cli.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/unit/cli/test_cli.py b/tests/unit/cli/test_cli.py index 6306e5644d8..fcc98376958 100644 --- a/tests/unit/cli/test_cli.py +++ b/tests/unit/cli/test_cli.py @@ -53,7 +53,16 @@ def test_subcommand_parser(self, mocker) -> None: parser = cli.subcommand_parser() assert parser.__class__.__name__ == "ArgumentParser" argument_list = [action.dest for action in parser._actions] - expected_argument = ["help", "verbose", "config", "print_config", "data_root", "task", "callback_monitor"] + expected_argument = [ + "help", + "verbose", + "config", + "print_config", + "data_root", + "task", + "seed", + "callback_monitor", + ] assert sorted(argument_list) == sorted(expected_argument) def test_add_subcommands(self, mocker) -> None: From 7702eb7c6dd9b2b9e7331fe5428aecf70d3a7e01 Mon Sep 17 00:00:00 2001 From: harimkang Date: Thu, 25 Jan 2024 10:08:53 +0900 Subject: [PATCH 3/5] Revert train.yaml --- src/otx/recipe/_base_/train.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/src/otx/recipe/_base_/train.yaml b/src/otx/recipe/_base_/train.yaml index df0ede1ae34..64ab9e551dc 100644 --- a/src/otx/recipe/_base_/train.yaml +++ b/src/otx/recipe/_base_/train.yaml @@ -50,6 +50,7 @@ logger: default_hp_metric: true prefix: "" deterministic: false +seed: null precision: 16 check_val_every_n_epoch: 1 gradient_clip_val: null From a21247e203655a3a61a2c2c418052e881314a337 Mon Sep 17 00:00:00 2001 From: harimkang Date: Thu, 25 Jan 2024 15:39:08 +0900 Subject: [PATCH 4/5] Revert Engine side --- src/otx/engine/engine.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/otx/engine/engine.py b/src/otx/engine/engine.py index 3ca7958ce22..2a2c7e17f29 100644 --- a/src/otx/engine/engine.py +++ b/src/otx/engine/engine.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Iterable import torch -from lightning import Trainer +from lightning import Trainer, seed_everything from otx.core.config.device import DeviceConfig from otx.core.data.module import OTXDataModule @@ -126,6 +126,7 @@ def __init__( def train( self, max_epochs: int = 10, + seed: int | None = None, deterministic: bool = False, precision: _PRECISION_INPUT | None = "32", val_check_interval: int | float | None = None, @@ -138,6 +139,7 @@ def train( Args: max_epochs (int | None, optional): The maximum number of epochs. Defaults to None. + seed (int | None, optional): The random seed. Defaults to None. deterministic (bool | None, optional): Whether to enable deterministic behavior. Defaults to False. precision (_PRECISION_INPUT | None, optional): The precision of the model. Defaults to 32. val_check_interval (int | float | None, optional): The validation check interval. Defaults to None. @@ -152,6 +154,7 @@ def train( Example: >>> engine.train( ... max_epochs=3, + ... seed=1234, ... deterministic=False, ... precision="32", ... ) @@ -185,6 +188,9 @@ def train( ) lit_module.meta_info = self.datamodule.meta_info + if seed is not None: + seed_everything(seed, workers=True) + self._build_trainer( logger=logger, callbacks=callbacks, From dea2fd06e96539de14b4ef1d74b91dffb32f11cc Mon Sep 17 00:00:00 2001 From: harimkang Date: Thu, 25 Jan 2024 15:52:39 +0900 Subject: [PATCH 5/5] Add skip argument seed with train in CLI --- src/otx/cli/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/otx/cli/cli.py b/src/otx/cli/cli.py index 06330f6a674..6879ec4e8dc 100644 --- a/src/otx/cli/cli.py +++ b/src/otx/cli/cli.py @@ -124,7 +124,7 @@ def engine_subcommands() -> dict[str, set[str]]: """ device_kwargs = {"accelerator", "devices"} return { - "train": device_kwargs, + "train": {"seed"}.union(device_kwargs), "test": {"datamodule"}.union(device_kwargs), "predict": {"datamodule"}.union(device_kwargs), "export": device_kwargs,