diff --git a/for_developers/setup_guide.md b/for_developers/setup_guide.md index 9379d5a8ce8..45bd5a6830b 100644 --- a/for_developers/setup_guide.md +++ b/for_developers/setup_guide.md @@ -51,18 +51,26 @@ Please see [requirements-lock.txt](requirements-lock.txt). This is what I got af - Launch detection task ATSS-R50-FPN template -```console -otx train +recipe=detection/atss_r50_fpn base.data_dir=tests/assets/car_tree_bug model.otx_model.config.bbox_head.num_classes=3 trainer.max_epochs=50 trainer.check_val_every_n_epoch=10 trainer=gpu base.work_dir=outputs/test_work_dir base.output_dir=outputs/test_output_dir -``` + ```console + otx train +recipe=detection/atss_r50_fpn base.data_dir=tests/assets/car_tree_bug model.otx_model.config.bbox_head.num_classes=3 trainer.max_epochs=50 trainer.check_val_every_n_epoch=10 trainer=gpu base.work_dir=outputs/test_work_dir base.output_dir=outputs/test_output_dir + ``` - Change subset names, e.g., "train" -> "train_16" (for training) -```console -otx train ... data.train_subset.subset_name= data.val_subset.subset_name= data.test_subset.subset_name= -``` + ```console + otx train ... data.train_subset.subset_name= data.val_subset.subset_name= data.test_subset.subset_name= + ``` - Do test with the best validation model checkpoint -```console -otx train ... test=true -``` + ```console + otx train ... test=true + ``` + +- Run an experiment with deterministic operations and a fixed seed + + ```console + otx train ... trainer.deterministic=True seed= + ``` + + `trainer.deterministic=True` might affect the model performance. Please see [this link](https://lightning.ai/docs/pytorch/stable/common/trainer.html#deterministic). Therefore, it is not recommended to turn this option on when comparing model performance. 
diff --git a/pyproject.toml b/pyproject.toml index 9694a1be4fc..1d553145fad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -269,6 +269,10 @@ max-returns = 10 "src/otx/**/*.py" = [ "ERA001", ] +# See https://github.com/openvinotoolkit/training_extensions/actions/runs/7109500350/job/19354528819?pr=2700 +"src/otx/core/config/**/*.py" = [ + "UP007" +] [tool.ruff.pydocstyle] convention = "google" diff --git a/src/otx/config/train.yaml b/src/otx/config/train.yaml index 8ff42e98638..bf2e8a466ef 100644 --- a/src/otx/config/train.yaml +++ b/src/otx/config/train.yaml @@ -11,6 +11,10 @@ defaults: train: true test: false +# If set to an integer value, e.g. `seed: 1`, +# Lightning derives unique seeds across all dataloader workers and processes for torch, numpy and stdlib random number generators. +seed: null + hydra: searchpath: - pkg://otx diff --git a/src/otx/core/config/__init__.py b/src/otx/core/config/__init__.py index fd7c1081c47..21bb6a29642 100644 --- a/src/otx/core/config/__init__.py +++ b/src/otx/core/config/__init__.py @@ -2,6 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 # """Config data type objects.""" + +from __future__ import annotations + from dataclasses import dataclass from typing import Optional @@ -13,7 +16,13 @@ @dataclass class TrainConfig: - """DTO for training.""" + """DTO for training. + + Attributes: + seed: If set to an integer value, e.g. `seed=1`, + Lightning derives unique seeds across all dataloader workers and processes + for torch, numpy and stdlib random number generators. 
+ """ base: BaseConfig callbacks: dict @@ -21,10 +30,12 @@ class TrainConfig: trainer: TrainerConfig model: ModelConfig logger: dict - recipe: Optional[str] # noqa: FA100 + recipe: Optional[str] train: bool test: bool + seed: Optional[int] = None + def register_configs() -> None: """Register DTO as default to hydra.""" diff --git a/src/otx/core/engine/train.py b/src/otx/core/engine/train.py index e8d0e31c13d..197894d7ca0 100644 --- a/src/otx/core/engine/train.py +++ b/src/otx/core/engine/train.py @@ -47,6 +47,8 @@ def train(cfg: TrainConfig) -> tuple[Trainer, dict[str, Any]]: :param cfg: A DictConfig configuration composed by Hydra. :return: A tuple with Pytorch Lightning Trainer and Python dict of metrics """ + from lightning import seed_everything + from otx.core.data.module import OTXDataModule from otx.core.engine.utils.instantiators import ( instantiate_callbacks, @@ -55,8 +57,8 @@ def train(cfg: TrainConfig) -> tuple[Trainer, dict[str, Any]]: from otx.core.engine.utils.logging_utils import log_hyperparameters # set seed for random number generators in pytorch, numpy and python.random - # if cfg.get("seed"): - # L.seed_everything(cfg.seed, workers=True) + if cfg.seed is not None: + seed_everything(cfg.seed, workers=True) log.info(f"Instantiating datamodule <{cfg.data}>") datamodule = OTXDataModule(task=cfg.base.task, config=cfg.data)