diff --git a/src/otx/cli/cli.py b/src/otx/cli/cli.py index 7d3f7cecf29..6879ec4e8dc 100644 --- a/src/otx/cli/cli.py +++ b/src/otx/cli/cli.py @@ -101,6 +101,11 @@ def subcommand_parser(self, **kwargs) -> ArgumentParser: type=str, help="Task Type.", ) + parser.add_argument( + "--seed", + type=int, + help="Sets seed for pseudo-random number generators in: pytorch, numpy, python.random.", + ) parser.add_argument( "--callback_monitor", type=str, @@ -119,7 +124,7 @@ def engine_subcommands() -> dict[str, set[str]]: """ device_kwargs = {"accelerator", "devices"} return { - "train": device_kwargs, + "train": {"seed"}.union(device_kwargs), "test": {"datamodule"}.union(device_kwargs), "predict": {"datamodule"}.union(device_kwargs), "export": device_kwargs, @@ -307,6 +312,19 @@ def save_config(self) -> None: skip_check=True, ) + def set_seed(self) -> None: + """Set the random seed for reproducibility. + + This method retrieves the seed value from the argparser and uses it to set the random seed. + If a seed value is provided, it will be used to set the random seed using the + `seed_everything` function from the `lightning` module. + """ + seed = self.get_config_value(self.config, "seed", None) + if seed is not None: + from lightning import seed_everything + + seed_everything(seed, workers=True) + def run(self) -> None: """Executes the specified subcommand. @@ -319,6 +337,7 @@ def run(self) -> None: otx_install(**self.config["install"]) elif self.subcommand in self.engine_subcommands(): + self.set_seed() self.instantiate_classes() fn_kwargs = self._prepare_subcommand_kwargs(self.subcommand) fn = getattr(self.engine, self.subcommand) diff --git a/tests/unit/cli/test_cli.py b/tests/unit/cli/test_cli.py index 6306e5644d8..fcc98376958 100644 --- a/tests/unit/cli/test_cli.py +++ b/tests/unit/cli/test_cli.py @@ -53,7 +53,16 @@ def test_subcommand_parser(self, mocker) -> None: parser = cli.subcommand_parser() assert parser.__class__.__name__ == "ArgumentParser" argument_list = [action.dest for action in parser._actions] - expected_argument = ["help", "verbose", "config", "print_config", "data_root", "task", "callback_monitor"] + expected_argument = [ + "help", + "verbose", + "config", + "print_config", + "data_root", + "task", + "seed", + "callback_monitor", + ] assert sorted(argument_list) == sorted(expected_argument) def test_add_subcommands(self, mocker) -> None: