diff --git a/CHANGELOG.md b/CHANGELOG.md
index 784a1581ee97a..d872ed68df5c0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -211,6 +211,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed setting `worker_init_fn` to seed dataloaders correctly when using DDP ([#7942](https://github.com/PyTorchLightning/pytorch-lightning/pull/7942))
+
+
 - Fixed `DataModule.prepare_data` could only be called on the global rank 0 process ([#7945](https://github.com/PyTorchLightning/pytorch-lightning/pull/7945))
diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py
index 51547d5576e74..7c20b7d1b3b1e 100644
--- a/pytorch_lightning/utilities/seed.py
+++ b/pytorch_lightning/utilities/seed.py
@@ -84,8 +84,9 @@ def reset_seed() -> None:
     If :func:`pytorch_lightning.utilities.seed.seed_everything` is unused, this function will do nothing.
     """
     seed = os.environ.get("PL_GLOBAL_SEED", None)
+    workers = os.environ.get("PL_SEED_WORKERS", False)
     if seed is not None:
-        seed_everything(int(seed))
+        seed_everything(int(seed), workers=bool(workers))
 
 
 def pl_worker_init_function(worker_id: int, rank: Optional = None) -> None:  # pragma: no cover
@@ -100,6 +101,9 @@ def pl_worker_init_function(worker_id: int, rank: Optional = None) -> None:  # p
     process_seed = torch.initial_seed()
     # back out the base seed so we can use all the bits
     base_seed = process_seed - worker_id
+    log.debug(
+        f'Initializing random number generators of process {global_rank} worker {worker_id} with base seed {base_seed}'
+    )
     ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
     # use 128 bits (4 x 32-bit words)
     np.random.seed(ss.generate_state(4))
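
For context on what the second hunk implements: each dataloader worker derives a unique 128-bit RNG state from the process's base seed, its own worker id, and the DDP global rank, so no two workers across ranks ever share a random stream. Below is a minimal standalone sketch of that derivation; the helper name `derive_worker_state` is hypothetical, and its arguments stand in for the values `pl_worker_init_function` reads from `torch.initial_seed()` and the distributed environment.

```python
# Minimal standalone sketch of the seeding scheme used in pl_worker_init_function.
# `derive_worker_state` is a hypothetical helper; the real function additionally
# seeds torch, python's `random`, and numpy's global RNG from the generated state.
import numpy as np


def derive_worker_state(base_seed: int, worker_id: int, global_rank: int) -> np.ndarray:
    # Mix all three inputs into one SeedSequence so each (worker, rank) pair
    # gets a statistically independent stream.
    ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
    # use 128 bits (4 x 32-bit words), matching the diff above
    return ss.generate_state(4)


# Same base seed, but a different worker or rank, yields a different state.
print(derive_worker_state(42, worker_id=0, global_rank=0))
print(derive_worker_state(42, worker_id=1, global_rank=0))
print(derive_worker_state(42, worker_id=0, global_rank=1))
```

The first hunk complements this: `seed_everything(seed, workers=True)` records the choice in the `PL_SEED_WORKERS` environment variable, so `reset_seed()` in newly spawned DDP processes re-applies not just the global seed but also the worker seeding behavior.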