diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 635166fb4081..f427caee646b 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -473,7 +473,7 @@ def collate_fn(examples): return batch train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn + train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1 ) # Scheduler and math around the number of training steps.