diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 2f72622d9dbf..1a03a3448a91 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1803,6 +1803,7 @@ def _setup_devices(self) -> "torch.device":
             torch.cuda.set_device(device)
         elif is_torch_xpu_available() and "ACCELERATE_USE_XPU" not in os.environ:
             os.environ["ACCELERATE_USE_XPU"] = "true"
+            self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout))
             device = torch.device("xpu:0")
             self._n_gpu = 1
         elif is_sagemaker_dp_enabled():
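
Note on the change: the XPU branch previously never assigned self.distributed_state, so it was created elsewhere with the default timeout and the user-configured ddp_timeout was silently ignored. The added line constructs the state here with the configured timeout, matching how the other device branches behave. Below is a minimal standalone sketch (not the patch itself) of the same call pattern; it assumes accelerate's PartialState forwards extra keyword arguments such as timeout to torch.distributed.init_process_group, and 1800 is TrainingArguments' documented default for ddp_timeout.

    from datetime import timedelta

    from accelerate import PartialState

    # Sketch: build the shared distributed state with an explicit process-group
    # timeout, as the added line does inside the XPU branch of _setup_devices.
    # 1800 seconds mirrors the default value of TrainingArguments.ddp_timeout.
    state = PartialState(timeout=timedelta(seconds=1800))
    print(state.device)  # e.g. "xpu:0" when an Intel XPU backend is active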