diff --git a/src/lightning_hydra_template/configs/serial_sweeper/seeds.yaml b/src/lightning_hydra_template/configs/serial_sweeper/seeds.yaml new file mode 100644 index 000000000..abf808ed7 --- /dev/null +++ b/src/lightning_hydra_template/configs/serial_sweeper/seeds.yaml @@ -0,0 +1,15 @@ +# This custom sweeper tries to emulate closely the syntax of Hydra's built-in BasicSweeper +# https://hydra.cc/docs/tutorials/basic/running_your_app/multi-run/#sweeping-via-hydrasweeperparams + +# The keys indicate the config nodes to sweep over, in dot-notation +# The values define how to sweep, and have to be one of the sweep overrides defined by Hydra +# https://hydra.cc/docs/advanced/override_grammar/extended/#sweeps + +params: + seed: range(5) # By default, sweep over 5 seeds + # For nested config values, use dot-notation, e.g.: + # model.optimizer.lr: choice(0.1,0.01,0.001) + +# If the experiment returns a value, e.g. a performance metric, `reduce` defines how to aggregate over the sweep +reduce: + _target_: numpy.nanmean diff --git a/src/lightning_hydra_template/configs/train.yaml b/src/lightning_hydra_template/configs/train.yaml index 566b84402..e49746af0 100644 --- a/src/lightning_hydra_template/configs/train.yaml +++ b/src/lightning_hydra_template/configs/train.yaml @@ -12,6 +12,9 @@ defaults: - paths: default - extras: default - hydra: default + # serial sweeper is a custom sweeper that emulates the syntax of Hydra's built-in BasicSweeper + # to run multiple experiments in serial in the same process, e.g. cross-validation + - serial_sweeper: null # experiment configs allow for version control of specific hyperparameters # e.g. 
best hyperparameters for given model and datamodule diff --git a/src/lightning_hydra_template/utils/utils.py b/src/lightning_hydra_template/utils/utils.py index f4260cbcc..c1fd6bceb 100644 --- a/src/lightning_hydra_template/utils/utils.py +++ b/src/lightning_hydra_template/utils/utils.py @@ -2,6 +2,7 @@ import copy import importlib import itertools +import math import operator import warnings from collections.abc import Callable @@ -268,10 +269,18 @@ def wrap(cfg: DictConfig) -> float | None: # Execute the task function and store the return value returns.append(task_func(current_cfg)) - # Try to guess if the task function is expected to return values - # Warn the user if a mix of values/None are returned - if 0 < sum(return_val is None for return_val in returns) < len(returns): - warning_msg = ( + # If the task function returns None across the sweep, skip aggregation and warn the user + if all(return_val is None for return_val in returns): + log.warning( + "All iterations of the sweep returned None! \n" + "If you don't need to locally aggregate return values of the sweep (e.g. for hyperparameter " + "optimization), consider switching to Hydra's built-in sweepers." + ) + return None + + # If some runs seem to have failed, warn the user + if 0 < sum(return_val is None or math.isnan(return_val) for return_val in returns) < len(returns): + log.warning( + "None returned for some iterations of the sweep! 
\n" "Return values for the sweep are: \n" + "\n".join( @@ -279,8 +288,6 @@ def wrap(cfg: DictConfig) -> float | None: ) ) - log.warning(warning_msg) - # Reduce the return values from the sweep to a single value return call(serial_sweeper_cfg.reduce, returns) diff --git a/tests/test_sweeps.py b/tests/test_sweeps.py index 0af7f2293..243512a3e 100644 --- a/tests/test_sweeps.py +++ b/tests/test_sweeps.py @@ -64,3 +64,20 @@ def test_optuna_sweep(tmp_path: Path) -> None: "++trainer.fast_dev_run=true", ] + overrides run_sh_command(command) + + +@RunIf(sh=True) +@pytest.mark.slow +def test_serial_sweep(tmp_path: Path) -> None: + """Test single-process serial sweeping. + + Args: + tmp_path: The temporary logging path. + """ + command = [ + startfile, + "serial_sweeper=seeds", + "hydra.run.dir=" + str(tmp_path), + "++trainer.fast_dev_run=true", + ] + overrides + run_sh_command(command)