
worker_init_fn is not set to seed dataloaders correctly when using DDP #7937

Closed
senarvi opened this issue Jun 11, 2021 · 8 comments · Fixed by #7942
Labels: bug (Something isn't working), help wanted (Open to be worked on), priority: 0 (High priority task)

Comments

senarvi (Contributor) commented Jun 11, 2021

🐛 Bug

When seed_everything(workers=True) is called, it sets the environment variable PL_SEED_WORKERS=1. Consequently, the Trainer sets the worker_init_fn of the dataloaders to pl_worker_init_function. It seems to me that worker_init_fn is not set when using DDP. The reason is that DDPPlugin.setup_environment() eventually runs reset_seed(), which reads the value of the PL_GLOBAL_SEED environment variable and calls seed_everything() with the default argument workers=False.
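
For clarity, here is a minimal sketch of that mechanism, using simplified stand-ins for Lightning's seed_everything() and reset_seed() (not the actual implementations):

import os

def seed_everything_sketch(seed, workers=False):
    # Simplified stand-in: record the global seed and the workers flag in env vars.
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    os.environ["PL_SEED_WORKERS"] = f"{int(workers)}"

def reset_seed_sketch():
    # Simplified stand-in for reset_seed(): it only restores PL_GLOBAL_SEED and
    # re-calls seed_everything with the default workers=False, clobbering the flag.
    seed = os.environ.get("PL_GLOBAL_SEED")
    if seed is not None:
        seed_everything_sketch(int(seed))

seed_everything_sketch(1234, workers=True)
print(os.environ["PL_SEED_WORKERS"])  # 1
reset_seed_sketch()                   # what DDPPlugin.setup_environment() ends up doing
print(os.environ["PL_SEED_WORKERS"])  # 0, so worker_init_fn is never attached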

Please reproduce using the BoringModel

It's not possible to reproduce the issue in Colab, since it doesn't support DDP.

To Reproduce

Here's a simple program that demonstrates the issue:

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning import seed_everything


class RandomDataset(Dataset):

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    seed_everything(1234, workers=True)
    # Sets PL_SEED_WORKERS=1
    print('PL_SEED_WORKERS=' + os.environ['PL_SEED_WORKERS'])

    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        weights_summary=None,
        gpus=2,
        accelerator='ddp'  # Using accelerator='dp' works
    )
    trainer.fit(model, train_dataloader=train_data, val_dataloaders=val_data)
    # Trainer.accelerator.setup_environment() calls DDPPlugin.setup_environment(),
    # which eventually runs seed_everything(workers=False) that sets PL_SEED_WORKERS=0
    # Consequently dataloader.worker_init_fn is not set.
    print('PL_SEED_WORKERS=' + os.environ['PL_SEED_WORKERS'])


if __name__ == '__main__':
    run()

Expected behavior

I would expect pl_worker_init_function to be called. By printing something from the function, I can see that it's called when I use the dp accelerator, but not when I use ddp. I can also see that the environment variable PL_SEED_WORKERS is reset to 0 during the Trainer.fit() call, whereas I would expect it to still have the value 1 at the end.

I think the correct fix would be to make reset_seed() read the PL_SEED_WORKERS environment variable too and pass the corresponding workers argument to seed_everything(). However, I'm not familiar enough with the code to be sure that this is correct.
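
A minimal sketch of that idea (not necessarily the implementation merged in #7942; the flag is stored as the string "0" or "1", so an int cast is used here because bool("0") would be truthy):

import os
from pytorch_lightning import seed_everything

def reset_seed():
    # Re-seed with the recorded global seed and restore the workers flag.
    seed = os.environ.get("PL_GLOBAL_SEED")
    workers = os.environ.get("PL_SEED_WORKERS", "0")
    if seed is not None:
        seed_everything(int(seed), workers=bool(int(workers)))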

Preferably, pl_worker_init_function would also log a message confirming that the workers were seeded correctly.

Environment

  • CUDA:
    • GPU:
      • NVIDIA Tesla V100-SXM2-16GB
      • NVIDIA Tesla V100-SXM2-16GB
    • available: True
    • version: 11.0
  • Packages:
    • numpy: 1.19.2
    • pyTorch_debug: True
    • pyTorch_version: 1.7.0
    • pytorch-lightning: 1.4.0dev
    • tqdm: 4.51.0
  • System:

Additional context

Recently there was discussion about an issue with data loading, where the same NumPy random seed is used across different workers. This causes the workers to use the same random numbers for data transforms. A fix was quickly introduced in PyTorch Lightning that seeds the dataloader workers correctly by automatically setting the worker_init_fn for dataloaders.
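
For reference, the general technique behind that fix looks roughly like the sketch below: a worker_init_fn that derives a per-worker NumPy seed from PyTorch's per-worker seed. This is a generic illustration of the idea, not Lightning's exact pl_worker_init_function:

import numpy as np
import torch

def seed_numpy_per_worker(worker_id):
    # torch.initial_seed() already differs per dataloader worker; fold it into
    # NumPy's RNG so random data transforms are not duplicated across workers.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)

# e.g. DataLoader(dataset, num_workers=4, worker_init_fn=seed_numpy_per_worker)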

senarvi added the bug and help wanted labels Jun 11, 2021
awaelchli (Contributor) commented Jun 11, 2021

Thanks for testing this feature out!

In seed_everything we have this line:

os.environ["PL_SEED_WORKERS"] = f"{int(workers)}"

should we change it to

os.environ["PL_SEED_WORKERS"] = os.environ.get("PL_SEED_WORKERS", f"{int(workers)}")

? And if so, would you like to test whether this change works for you?

awaelchli added this to the v1.3.x milestone Jun 11, 2021
senarvi (Contributor, Author) commented Jun 11, 2021

should we change it to

os.environ["PL_SEED_WORKERS"] = os.get("PL_SEED_WORKERS", f"{int(workers)}")

? And if so would you like to test if this change works for you?

os.environ["PL_SEED_WORKERS"] = os.environ.get("PL_SEED_WORKERS", f"{int(workers)}") works. It means that seed_everything() will ignore the workers argument if PL_SEED_WORKERS is set, but it works too (I already tested).

awaelchli (Contributor) commented Jun 11, 2021

It does mean that seed_everything() will ignore the workers argument if PL_SEED_WORKERS is already set.

Maybe a better fix would be to actually change reset_seed() to call seed_everything with workers=bool(int(os.environ.get("PL_SEED_WORKERS", 0)))

so that we don't have to ignore the argument when someone does:

seed_everything(123, workers=True)

# .. later on for a second training with different dataloaders maybe??
seed_everything(123, workers=False)  # don't ignore this, actually turn it off

senarvi (Contributor, Author) commented Jun 11, 2021

Maybe a better fix would be to actually change reset_seed() to call seed_everything with workers=bool(int(os.environ.get("PL_SEED_WORKERS", 0)))

Right, that was my initial thought too. Also, what do you think about writing a log message in pl_worker_init_function() to confirm that the dataloaders have been initialized correctly?

awaelchli (Contributor) commented Jun 11, 2021

I think it's a good idea in general, but note that we can't do it for every worker; otherwise we get a large output with N * W messages, where N is the number of GPUs and W is the number of workers per process. Maybe we can do two things (a rough sketch of both follows):

  1. only log it at DEBUG level inside the worker_init_fn (so it is not visible by default unless the user sets the logging level to debug), and
  2. additionally, where we set the worker_init_fn, confirm that it was set on the dataloader, this time at the regular log level.
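
A rough sketch of both points, using hypothetical function names (the real hook lives inside Lightning's dataloader setup):

import logging

import numpy as np
import torch
from torch.utils.data import DataLoader

log = logging.getLogger(__name__)

def worker_init_fn_sketch(worker_id):
    # Point 1: per-worker message only at DEBUG level, so the N * W lines stay
    # hidden unless the user raises the logging verbosity.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    log.debug("Seeded NumPy in dataloader worker %d with seed %d", worker_id, worker_seed)

def attach_worker_init_fn(dataloader: DataLoader):
    # Point 2: a single confirmation at the regular log level, emitted where the
    # worker_init_fn is attached to the dataloader.
    if dataloader.worker_init_fn is None:
        dataloader.worker_init_fn = worker_init_fn_sketch
        log.info("worker_init_fn set on dataloader; its workers will be seeded")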

awaelchli (Contributor) commented

would you be interested in sending a PR for the fix, or the log message, or both? :) happy to help wherever.

senarvi (Contributor, Author) commented Jun 11, 2021

would you be interested in sending a PR for the fix, or the log message, or both? :) happy to help wherever.

I can do that.

senarvi (Contributor, Author) commented Jun 11, 2021

The pull request is created: #7942.

I haven't had a chance to run the test suite yet.
