Skip to content

Commit

Permalink
Add example config for serial_sweeper group + pytest
Browse files Browse the repository at this point in the history
Also fix how `hydra_serial_sweeper` decorator handles job returning Nones or NaNs
  • Loading branch information
nathanpainchaud committed Dec 20, 2024
1 parent e8882bb commit 0e43234
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 6 deletions.
15 changes: 15 additions & 0 deletions src/lightning_hydra_template/configs/serial_sweeper/seeds.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# This custom sweeper tries to emulate closely the syntax of Hydra's built-in BasicSweeper
# https://hydra.cc/docs/tutorials/basic/running_your_app/multi-run/#sweeping-via-hydrasweeperparams

# The keys indicate the config nodes to sweep over, in dot-notation
# The values define how to sweep, and have to be one of the sweep overrides defined by Hydra
# https://hydra.cc/docs/advanced/override_grammar/extended/#sweeps

params:
seed: range(5) # By default, sweep over 5 seeds
# For nested config values, use dot-notation, e.g.:
# model.optimizer.lr: choice(0.1,0.01,0.001)

# If the experiment return a value, e.g. a performance metric, `reduce` defines how to aggregate over the sweep
reduce:
_target_: numpy.nanmean
3 changes: 3 additions & 0 deletions src/lightning_hydra_template/configs/train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ defaults:
- paths: default
- extras: default
- hydra: default
# serial sweeper is a custom sweeper that emulates the syntax of Hydra's built-in BasicSweeper
# to run multiple experiments in serial in the same process, e.g. cross-validation
- serial_sweeper: null

# experiment configs allow for version control of specific hyperparameters
# e.g. best hyperparameters for given model and datamodule
Expand Down
19 changes: 13 additions & 6 deletions src/lightning_hydra_template/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import copy
import importlib
import itertools
import math
import operator
import warnings
from collections.abc import Callable
Expand Down Expand Up @@ -268,19 +269,25 @@ def wrap(cfg: DictConfig) -> float | None:
# Execute the task function and store the return value
returns.append(task_func(current_cfg))

# Try to guess if the task function is expected to return values
# Warn the user if a mix of values/None are returned
if 0 < sum(return_val is None for return_val in returns) < len(returns):
warning_msg = (
# If task function returns None across the sweep, skip aggregation and warn the user
if all(return_val is None for return_val in returns):
log.warning(
"All iterations of <cfg.serial_sweeper> returned None! \n"
"If you don't need to locally aggregate return values of the sweep (e.g. for hyperparameter "
"optimization), consider switching to Hydra's built-in sweepers."
)
return None

# If some runs seem to have failed, warn the user
if 0 < sum(return_val is None or math.isnan(return_val) for return_val in returns) < len(returns):
log.warning(
"None returned for some iterations of <cfg.serial_sweeper>! \n"
"Return values for the sweep are: \n"
+ "\n".join(
f"{params} -> {return_val}" for params, return_val in zip(params_sets, returns, strict=False)
)
)

log.warning(warning_msg)

# Reduce the return values from the sweep to a single value
return call(serial_sweeper_cfg.reduce, returns)

Expand Down
17 changes: 17 additions & 0 deletions tests/test_sweeps.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,20 @@ def test_optuna_sweep(tmp_path: Path) -> None:
"++trainer.fast_dev_run=true",
] + overrides
run_sh_command(command)


@RunIf(sh=True)
@pytest.mark.slow
def test_serial_sweep(tmp_path: Path) -> None:
"""Test single-process serial sweeping.
Args:
tmp_path: The temporary logging path.
"""
command = [
startfile,
"serial_sweeper=seeds",
"hydra.run.dir=" + str(tmp_path),
"++trainer.fast_dev_run=true",
] + overrides
run_sh_command(command)

0 comments on commit 0e43234

Please sign in to comment.