diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 152efd6ccd0..dcbdfccb0eb 100644 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -1151,7 +1151,8 @@ def prepare_model(self, model: torch.nn.Module, device_placement=None): def _prepare_deepspeed(self, *args): deepspeed_plugin = self.state.deepspeed_plugin - if deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] == "auto": + is_dataloader_present = any(isinstance(obj, torch.utils.data.DataLoader) for obj in args) + if deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] == "auto" or is_dataloader_present: result = [ self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) else obj for obj in args