diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index ab84237a..a6da7b23 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -396,8 +396,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): device = fabric.device rank = fabric.global_rank world_size = fabric.world_size - fabric.seed_everything(cfg.seed) - torch.backends.cudnn.deterministic = cfg.torch_deterministic if cfg.checkpoint.resume_from: state = fabric.load(cfg.checkpoint.resume_from) diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index 9bef5dbb..4d4d561f 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -1,6 +1,7 @@ """Dreamer-V2 implementation from [https://arxiv.org/abs/2010.02193](https://arxiv.org/abs/2010.02193). Adapted from the original implementation from https://github.com/danijar/dreamerv2 """ + from __future__ import annotations import copy @@ -419,8 +420,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): device = fabric.device rank = fabric.global_rank world_size = fabric.world_size - fabric.seed_everything(cfg.seed) - torch.backends.cudnn.deterministic = cfg.torch_deterministic if cfg.checkpoint.resume_from: state = fabric.load(cfg.checkpoint.resume_from) diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index 94e56f75..49634131 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -384,8 +384,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): device = fabric.device rank = fabric.global_rank world_size = fabric.world_size - fabric.seed_everything(cfg.seed) - torch.backends.cudnn.deterministic = cfg.torch_deterministic if cfg.checkpoint.resume_from: state = fabric.load(cfg.checkpoint.resume_from) diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index 1900c589..57510355 100644 --- 
a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -140,8 +140,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): device = fabric.device rank = fabric.global_rank world_size = fabric.world_size - fabric.seed_everything(cfg.seed) - torch.backends.cudnn.deterministic = cfg.torch_deterministic # Resume from checkpoint if cfg.checkpoint.resume_from: diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py index d91e0334..3af2c83c 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py @@ -394,8 +394,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): device = fabric.device rank = fabric.global_rank world_size = fabric.world_size - fabric.seed_everything(cfg.seed) - torch.backends.cudnn.deterministic = cfg.torch_deterministic if cfg.checkpoint.resume_from: state = fabric.load(cfg.checkpoint.resume_from) diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py index 632c0eb3..a9596ddf 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py @@ -34,8 +34,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): device = fabric.device rank = fabric.global_rank world_size = fabric.world_size - fabric.seed_everything(cfg.seed) - torch.backends.cudnn.deterministic = cfg.torch_deterministic ckpt_path = pathlib.Path(cfg.checkpoint.exploration_ckpt_path) resume_from_checkpoint = cfg.checkpoint.resume_from is not None diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py index 1f478241..3d6c073c 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py @@ -514,8 +514,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): device = fabric.device rank = fabric.global_rank world_size = fabric.world_size - fabric.seed_everything(cfg.seed) - 
torch.backends.cudnn.deterministic = cfg.torch_deterministic if cfg.checkpoint.resume_from: state = fabric.load(cfg.checkpoint.resume_from) diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py index 411999dc..c7909c47 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py @@ -34,8 +34,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): device = fabric.device rank = fabric.global_rank world_size = fabric.world_size - fabric.seed_everything(cfg.seed) - torch.backends.cudnn.deterministic = cfg.torch_deterministic ckpt_path = pathlib.Path(cfg.checkpoint.exploration_ckpt_path) resume_from_checkpoint = cfg.checkpoint.resume_from is not None diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py index a092eca0..d9751131 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py @@ -550,8 +550,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): device = fabric.device rank = fabric.global_rank world_size = fabric.world_size - fabric.seed_everything(cfg.seed) - torch.backends.cudnn.deterministic = cfg.torch_deterministic if cfg.checkpoint.resume_from: state = fabric.load(cfg.checkpoint.resume_from) diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py index a36b2989..f44405e8 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py @@ -29,8 +29,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): device = fabric.device rank = fabric.global_rank world_size = fabric.world_size - fabric.seed_everything(cfg.seed) - torch.backends.cudnn.deterministic = cfg.torch_deterministic ckpt_path = pathlib.Path(cfg.checkpoint.exploration_ckpt_path) resume_from_checkpoint = cfg.checkpoint.resume_from is not None diff --git 
a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index 04ddf469..1fa8f187 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -119,8 +119,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): rank = fabric.global_rank world_size = fabric.world_size device = fabric.device - fabric.seed_everything(cfg.seed) - torch.backends.cudnn.deterministic = cfg.torch_deterministic # Resume from checkpoint if cfg.checkpoint.resume_from: diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index f8a81b5f..d03db22f 100644 --- a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -35,8 +35,6 @@ def player( # Initialize the fabric object log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name, False) device = fabric.device - fabric.seed_everything(cfg.seed) - torch.backends.cudnn.deterministic = cfg.torch_deterministic # Resume from checkpoint if cfg.checkpoint.resume_from: @@ -377,8 +375,6 @@ def trainer( ) fabric.launch() device = fabric.device - fabric.seed_everything(cfg.seed) - torch.backends.cudnn.deterministic = cfg.torch_deterministic # Resume from checkpoint if cfg.checkpoint.resume_from: diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py index f643373f..ab71a365 100644 --- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py +++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py @@ -127,8 +127,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): rank = fabric.global_rank world_size = fabric.world_size device = fabric.device - fabric.seed_everything(cfg.seed) - torch.backends.cudnn.deterministic = cfg.torch_deterministic # Resume from checkpoint if cfg.checkpoint.resume_from: diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index bde90c4f..52521b1c 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -91,8 +91,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): device = fabric.device rank = fabric.global_rank 
world_size = fabric.world_size - fabric.seed_everything(cfg.seed) - torch.backends.cudnn.deterministic = cfg.torch_deterministic # Resume from checkpoint if cfg.checkpoint.resume_from: diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index 08dd9eab..8eacf018 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -35,8 +35,6 @@ def player( log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name, False) rank = fabric.global_rank device = fabric.device - fabric.seed_everything(cfg.seed) - torch.backends.cudnn.deterministic = cfg.torch_deterministic # Resume from checkpoint if cfg.checkpoint.resume_from: @@ -352,8 +350,6 @@ def trainer( ) fabric.launch() device = fabric.device - fabric.seed_everything(cfg.seed) - torch.backends.cudnn.deterministic = cfg.torch_deterministic # Resume from checkpoint if cfg.checkpoint.resume_from: diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index d53f64ea..b59521fb 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -134,8 +134,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): device = fabric.device rank = fabric.global_rank world_size = fabric.world_size - fabric.seed_everything(cfg.seed) - torch.backends.cudnn.deterministic = cfg.torch_deterministic # Resume from checkpoint if cfg.checkpoint.resume_from: diff --git a/sheeprl/cli.py b/sheeprl/cli.py index e08a3b10..0947a74f 100644 --- a/sheeprl/cli.py +++ b/sheeprl/cli.py @@ -166,7 +166,24 @@ def run_algorithm(cfg: Dict[str, Any]): for k in keys_to_remove: cfg.model_manager.models.pop(k, None) cfg.model_manager.disabled == cfg.model_manager.disabled or len(cfg.model_manager.models) == 0 - fabric.launch(command, cfg, **kwargs) + + # This function is used to make the algorithm reproducible. 
+ # It may be overkill, since Fabric already captures everything we're setting here + # when multiprocessing is used with a `spawn` method (default with DDP strategy). + # https://github.com/Lightning-AI/pytorch-lightning/blob/f23b3b1e7fdab1d325f79f69a28706d33144f27e/src/lightning/fabric/strategies/launchers/multiprocessing.py#L112 + def reproducible(func): + def wrapper(fabric: Fabric, cfg: Dict[str, Any], *args, **kwargs): + if cfg.cublas_workspace_config is not None: + os.environ["CUBLAS_WORKSPACE_CONFIG"] = cfg.cublas_workspace_config + fabric.seed_everything(cfg.seed) + torch.backends.cudnn.benchmark = cfg.torch_backends_cudnn_benchmark + torch.backends.cudnn.deterministic = cfg.torch_backends_cudnn_deterministic + torch.use_deterministic_algorithms(cfg.torch_use_deterministic_algorithms) + return func(fabric, cfg, *args, **kwargs) + + return wrapper + + fabric.launch(reproducible(command), cfg, **kwargs) def eval_algorithm(cfg: DictConfig): diff --git a/sheeprl/configs/config.yaml b/sheeprl/configs/config.yaml index 0a743380..ec841e7d 100644 --- a/sheeprl/configs/config.yaml +++ b/sheeprl/configs/config.yaml @@ -22,7 +22,35 @@ dry_run: False # Reproducibility seed: 42 -torch_deterministic: False + +# For more information about reproducibility in PyTorch, see https://pytorch.org/docs/stable/notes/randomness.html + +# torch.use_deterministic_algorithms() lets you configure PyTorch to use deterministic algorithms +# instead of nondeterministic ones where available, +# and to throw an error if an operation is known to be nondeterministic (and without a deterministic alternative). +torch_use_deterministic_algorithms: False + +# Disabling the benchmarking feature with torch.backends.cudnn.benchmark = False +# causes cuDNN to deterministically select an algorithm, possibly at the cost of reduced performance. 
+# However, if you do not need reproducibility across multiple executions of your application, +# then performance might improve if the benchmarking feature is enabled with torch.backends.cudnn.benchmark = True. +torch_backends_cudnn_benchmark: True + +# While disabling CUDA convolution benchmarking (discussed above) ensures that CUDA selects the same algorithm each time an application is run, +# that algorithm itself may be nondeterministic, unless either torch.use_deterministic_algorithms(True) +# or torch.backends.cudnn.deterministic = True is set. +# The latter setting controls only this behavior, +# unlike torch.use_deterministic_algorithms() which will make other PyTorch operations behave deterministically, too. +torch_backends_cudnn_deterministic: False + +# From: https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility +# By design, all cuBLAS API routines from a given toolkit version, generate the same bit-wise results at every run +# when executed on GPUs with the same architecture and the same number of SMs. +# However, bit-wise reproducibility is not guaranteed across toolkit versions +# because the implementation might differ due to some implementation changes. +# This guarantee holds when a single CUDA stream is active only. +# If multiple concurrent streams are active, the library may optimize total performance by picking different internal implementations. +cublas_workspace_config: null # Possible values are: ":4096:8" or ":16:8" # Output folders exp_name: ${algo.name}_${env.id}