From 9d90fa41492d95c2952cee2b05a6f4a34afb1815 Mon Sep 17 00:00:00 2001
From: Matthew Peters
Date: Mon, 27 Jul 2020 09:37:02 -0700
Subject: [PATCH] fix checkpointing on TPU

---
 .../callbacks/model_checkpoint.py        | 19 +++++------
 pytorch_lightning/trainer/training_io.py | 34 +++++++++++++------
 2 files changed, 32 insertions(+), 21 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 6fee7bdd6cc6b..92d0e1494a408 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -292,12 +292,12 @@ def on_train_start(self, trainer, pl_module):
         if not gfile.exists(self.dirpath):
             makedirs(self.dirpath)
 
-    @rank_zero_only
     def on_validation_end(self, trainer, pl_module):
-        # only run on main process
-        if trainer.global_rank != 0:
-            return
-
+        # To get checkpointing working on TPU, need to call _save_model
+        # for all ranks, to avoid deadlocks. Assuming save_function is mapped
+        # to trainer.save_checkpoint, this will also work on GPU as save_checkpoint
+        # handles rank==0 vs rank!=0 logic. If the user provides a custom
+        # save_function, they are responsible for adding rank==0 vs rank!=0 logic.
         metrics = trainer.callback_metrics
         epoch = trainer.current_epoch
 
@@ -348,8 +348,6 @@ def on_validation_end(self, trainer, pl_module):
             else:
                 if self.verbose > 0:
                     log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
-
-                assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
                 self._save_model(filepath, trainer, pl_module)
 
         if self.save_last:
@@ -384,6 +382,7 @@ def _do_check_save(self, filepath, current, epoch, trainer, pl_module):
                      f' {filepath} as top {self.save_top_k}')
         self._save_model(filepath, trainer, pl_module)
 
-        for cur_path in del_list:
-            if cur_path != filepath:
-                self._del_model(cur_path)
+        if trainer.global_rank == 0:
+            for cur_path in del_list:
+                if cur_path != filepath:
+                    self._del_model(cur_path)
diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index 7a1613b919a26..0d1dbec9b77f3 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -270,29 +270,41 @@ def _atomic_save(self, checkpoint, filepath: str):
             This points to the file that the checkpoint will be stored in.
         """
         tmp_path = str(filepath) + ".part"
+        if self.use_tpu:
+            xm.save(checkpoint, tmp_path, master_only=True, global_master=True)
+            if xm.is_master_ordinal(local=False):
+                os.replace(tmp_path, filepath)
         # Can't use the new zipfile serialization for 1.6.0 because there's a bug in
         # torch.hub.load_state_dict_from_url() that prevents it from loading the new files.
         # More details can be found here: https://github.com/pytorch/pytorch/issues/42239
-        if LooseVersion(torch.__version__).version[:3] == [1, 6, 0]:
+        elif LooseVersion(torch.__version__).version[:3] == [1, 6, 0]:
             torch.save(checkpoint, tmp_path, _use_new_zipfile_serialization=False)
         else:
             torch.save(checkpoint, tmp_path)
         os.replace(tmp_path, filepath)
 
     def save_checkpoint(self, filepath, weights_only: bool = False):
-        checkpoint = self.dump_checkpoint(weights_only)
-
-        if self.is_global_zero:
+        def _do_save(chkpt):
             # do the actual save
             try:
-                self._atomic_save(checkpoint, filepath)
+                self._atomic_save(chkpt, filepath)
             except AttributeError as err:
-                if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
-                    del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
-                rank_zero_warn(
-                    'Warning, `module_arguments` dropped from checkpoint.' f' An attribute is not picklable {err}'
-                )
-                self._atomic_save(checkpoint, filepath)
+                if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in chkpt:
+                    del chkpt[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
+                rank_zero_warn('Warning, `module_arguments` dropped from checkpoint.'
+                               f' An attribute is not picklable {err}')
+                self._atomic_save(chkpt, filepath)
+
+        checkpoint = self.dump_checkpoint(weights_only)
+
+        # self._atomic_save has different behavior for XLA vs
+        # non-XLA. In XLA, it has a barrier and internal logic to only
+        # save for rank==0, so need to call for all ranks. For non-XLA,
+        # it doesn't have rank==0 logic so only call for rank==0
+        if self.use_tpu:
+            _do_save(checkpoint)
+        elif self.is_global_zero:
+            _do_save(checkpoint)
 
     def restore(self, checkpoint_path: str, on_gpu: bool):
         """
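
For context, a minimal standalone sketch of the all-ranks save pattern this patch relies on. It assumes torch_xla is installed; atomic_tpu_save is a hypothetical helper and is not part of the patch. xm.save() synchronizes across TPU processes, so every rank must reach the call even though only the global master writes the file, which is why the rank==0 early return had to be removed above.

    import os
    import torch_xla.core.xla_model as xm

    def atomic_tpu_save(checkpoint: dict, filepath: str):
        # Every process must call xm.save(); with master_only=True only the
        # global master actually serializes to disk, while the other ranks
        # just participate in the synchronization (avoiding the deadlock the
        # patch fixes).
        tmp_path = str(filepath) + ".part"
        xm.save(checkpoint, tmp_path, master_only=True, global_master=True)
        if xm.is_master_ordinal(local=False):
            # Only the global master finishes the atomic rename.
            os.replace(tmp_path, filepath)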