Turn on setting seed and deterministic flags #2700

Merged 2 commits on Dec 6, 2023
26 changes: 17 additions & 9 deletions for_developers/setup_guide.md
@@ -51,18 +51,26 @@ Please see [requirements-lock.txt](requirements-lock.txt). This is what I got af

- Launch detection task ATSS-R50-FPN template

```console
otx train +recipe=detection/atss_r50_fpn base.data_dir=tests/assets/car_tree_bug model.otx_model.config.bbox_head.num_classes=3 trainer.max_epochs=50 trainer.check_val_every_n_epoch=10 trainer=gpu base.work_dir=outputs/test_work_dir base.output_dir=outputs/test_output_dir
```

- Change subset names, e.g., "train" -> "train_16" (for training)

```console
otx train ... data.train_subset.subset_name=<arbitrary-name> data.val_subset.subset_name=<arbitrary-name> data.test_subset.subset_name=<arbitrary-name>
```

- Run a test with the best validation model checkpoint

```console
otx train ... test=true
```

- Run an experiment with deterministic operations and a fixed seed

```console
otx train ... trainer.deterministic=True seed=<arbitrary-seed>
```

`trainer.deterministic=True` might affect model performance; see [the Lightning `Trainer` documentation](https://lightning.ai/docs/pytorch/stable/common/trainer.html#deterministic). It is therefore not recommended to turn this option on when comparing model performance.
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -269,6 +269,10 @@ max-returns = 10
"src/otx/**/*.py" = [
"ERA001",
]
# See https://github.com/openvinotoolkit/training_extensions/actions/runs/7109500350/job/19354528819?pr=2700
"src/otx/core/config/**/*.py" = [
"UP007"
]

[tool.ruff.pydocstyle]
convention = "google"
4 changes: 4 additions & 0 deletions src/otx/config/train.yaml
@@ -11,6 +11,10 @@ defaults:
train: true
test: false

# If set to an integer value, e.g. `seed: 1`,
# Lightning derives unique seeds across all dataloader workers and processes for torch, numpy and stdlib random number generators.
seed: null

hydra:
searchpath:
- pkg://otx
15 changes: 13 additions & 2 deletions src/otx/core/config/__init__.py
@@ -2,6 +2,9 @@
# SPDX-License-Identifier: Apache-2.0
#
"""Config data type objects."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

@@ -13,18 +16,26 @@

@dataclass
class TrainConfig:
    """DTO for training."""
    """DTO for training.

    Attributes:
        seed: If set to an integer value, e.g. `seed=1`,
            Lightning derives unique seeds across all dataloader workers and processes
            for torch, numpy and stdlib random number generators.
    """

    base: BaseConfig
    callbacks: dict
    data: DataModuleConfig
    trainer: TrainerConfig
    model: ModelConfig
    logger: dict
    recipe: Optional[str] # noqa: FA100
    recipe: Optional[str]
    train: bool
    test: bool

    seed: Optional[int] = None


def register_configs() -> None:
"""Register DTO as default to hydra."""
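The new `seed` field is the only one in `TrainConfig` with a default value, and it is declared last. In a Python dataclass that ordering is required: a field without a default cannot follow one that has a default. A minimal sketch of the rule, using a hypothetical, trimmed-down `MiniTrainConfig` (an illustration only, not the real OTX class):

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional


@dataclass
class MiniTrainConfig:
    """Hypothetical, trimmed-down analogue of TrainConfig."""

    train: bool
    test: bool
    # The defaulted field comes after all required fields.
    seed: Optional[int] = None


# Omitting the seed leaves it at its default of None.
cfg = MiniTrainConfig(train=True, test=False)
print(cfg.seed)

error_message = ""
try:
    # Declaring a required field after a defaulted one is rejected
    # as soon as the dataclass decorator processes the class.
    @dataclass
    class BadOrder:
        seed: Optional[int] = None
        train: bool
except TypeError as err:
    error_message = str(err)
    print(error_message)
```

Existing configurations that never mention `seed` therefore keep working, since the field silently falls back to `None`.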
6 changes: 4 additions & 2 deletions src/otx/core/engine/train.py
@@ -47,6 +47,8 @@ def train(cfg: TrainConfig) -> tuple[Trainer, dict[str, Any]]:
    :param cfg: A DictConfig configuration composed by Hydra.
    :return: A tuple with Pytorch Lightning Trainer and Python dict of metrics
    """
    from lightning import seed_everything

    from otx.core.data.module import OTXDataModule
    from otx.core.engine.utils.instantiators import (
        instantiate_callbacks,
@@ -55,8 +57,8 @@ def train(cfg: TrainConfig) -> tuple[Trainer, dict[str, Any]]:
    from otx.core.engine.utils.logging_utils import log_hyperparameters

    # set seed for random number generators in pytorch, numpy and python.random
    # if cfg.get("seed"):
    #     L.seed_everything(cfg.seed, workers=True)
    if cfg.seed is not None:
        seed_everything(cfg.seed, workers=True)

    log.info(f"Instantiating datamodule <{cfg.data}>")
    datamodule = OTXDataModule(task=cfg.base.task, config=cfg.data)
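One detail of the guard above is worth spelling out: the merged code checks `cfg.seed is not None`, whereas the old commented-out line relied on truthiness. The two differ for `seed=0`, which is a usable seed but falsy in Python. A minimal sketch of the difference, assuming a hypothetical `MiniConfig` dataclass and a stand-in `fake_seed_everything` that only records its argument (neither is part of OTX or Lightning):

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional


@dataclass
class MiniConfig:
    """Hypothetical stand-in for TrainConfig; only the seed field is modelled."""

    seed: Optional[int] = None


def fake_seed_everything(seed: int, calls: list[int]) -> None:
    """Stand-in for lightning.seed_everything; it only records the seed."""
    calls.append(seed)


def seed_if_truthy(cfg: MiniConfig, calls: list[int]) -> None:
    # Truthiness check, in the spirit of the old commented-out code.
    if cfg.seed:
        fake_seed_everything(cfg.seed, calls)


def seed_if_not_none(cfg: MiniConfig, calls: list[int]) -> None:
    # Explicit None check, as in the merged code.
    if cfg.seed is not None:
        fake_seed_everything(cfg.seed, calls)


truthy_calls: list[int] = []
explicit_calls: list[int] = []
for seed in (None, 0, 1):
    seed_if_truthy(MiniConfig(seed=seed), truthy_calls)
    seed_if_not_none(MiniConfig(seed=seed), explicit_calls)

print(truthy_calls)    # [1]
print(explicit_calls)  # [0, 1]
```

With the explicit check, `seed=0` is honoured and only the default `seed: null` skips seeding.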