From 4bf13a0cbd058465496d5866eec5ed2e6a4f260c Mon Sep 17 00:00:00 2001
From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date: Tue, 28 Feb 2023 18:40:33 +0530
Subject: [PATCH] deepspeed dataloader prepare fix

---
 src/accelerate/accelerator.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 152efd6ccd0..dcbdfccb0eb 100644
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -1151,7 +1151,8 @@ def prepare_model(self, model: torch.nn.Module, device_placement=None):
 
     def _prepare_deepspeed(self, *args):
         deepspeed_plugin = self.state.deepspeed_plugin
-        if deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] == "auto":
+        is_dataloader_present = any(isinstance(obj, torch.utils.data.DataLoader) for obj in args)
+        if deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] == "auto" or is_dataloader_present:
             result = [
                 self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) else obj
                 for obj in args