Skip to content

Commit

Permalink
deepspeed dataloader prepare fix (#1126)
Browse files — browse the repository at this point in the history
  • Loading branch information
pacman100 authored Mar 2, 2023
1 parent 9b5877d commit 075b5d6
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,7 +1152,8 @@ def prepare_model(self, model: torch.nn.Module, device_placement=None):
def _prepare_deepspeed(self, *args):
deepspeed_plugin = self.state.deepspeed_plugin

-        if deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] == "auto":
+        is_dataloader_present = any(isinstance(obj, torch.utils.data.DataLoader) for obj in args)
+        if deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] == "auto" or is_dataloader_present:
result = [
self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) else obj
for obj in args
Expand Down

0 comments on commit 075b5d6

Please sign in to comment.