Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] fix checkpointing on TPU #2726

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,12 +292,12 @@ def on_train_start(self, trainer, pl_module):
if not gfile.exists(self.dirpath):
makedirs(self.dirpath)

@rank_zero_only
def on_validation_end(self, trainer, pl_module):
# only run on main process
if trainer.global_rank != 0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may break DDP.

return

# To get checkpointing working on TPU, need to call _save_model
# for all ranks, to avoid deadlocks. Assuming save_function is mapped
Copy link
Contributor

@ibeltagy ibeltagy Jul 27, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: without `if trainer.global_rank != 0:` and `@rank_zero_only`, all threads write to the log, which makes it a little messy.

# to trainer.save_checkpoint, this will also work on GPU as save_checkpoint
# handles rank==0 vs rank!=0 logic. If the user provides a custom
# save_function, they are responsible for adding rank==0 vs rank!=0 logic.
metrics = trainer.callback_metrics
epoch = trainer.current_epoch

Expand Down Expand Up @@ -348,8 +348,6 @@ def on_validation_end(self, trainer, pl_module):
else:
if self.verbose > 0:
log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')

assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
self._save_model(filepath, trainer, pl_module)

if self.save_last:
Expand Down Expand Up @@ -384,6 +382,7 @@ def _do_check_save(self, filepath, current, epoch, trainer, pl_module):
f' {filepath} as top {self.save_top_k}')
self._save_model(filepath, trainer, pl_module)

for cur_path in del_list:
if cur_path != filepath:
self._del_model(cur_path)
if trainer.global_rank == 0:
for cur_path in del_list:
if cur_path != filepath:
self._del_model(cur_path)
34 changes: 23 additions & 11 deletions pytorch_lightning/trainer/training_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,29 +270,41 @@ def _atomic_save(self, checkpoint, filepath: str):
This points to the file that the checkpoint will be stored in.
"""
tmp_path = str(filepath) + ".part"
if self.use_tpu:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is `on_tpu`.

xm.save(checkpoint, tmp_path, master_only=True, global_master=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a barrier before xm.save, to make sure all processes are in sync?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The barrier is already inside `xm.save` here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excuse my naivety, but shouldn't there be a barrier before the save, to make sure that the weights have been updated by every process?

if xm.is_master_ordinal(local=False):
os.replace(tmp_path, filepath)
# Can't use the new zipfile serialization for 1.6.0 because there's a bug in
# torch.hub.load_state_dict_from_url() that prevents it from loading the new files.
# More details can be found here: https://github.com/pytorch/pytorch/issues/42239
if LooseVersion(torch.__version__).version[:3] == [1, 6, 0]:
elif LooseVersion(torch.__version__).version[:3] == [1, 6, 0]:
torch.save(checkpoint, tmp_path, _use_new_zipfile_serialization=False)
else:
torch.save(checkpoint, tmp_path)
os.replace(tmp_path, filepath)

def save_checkpoint(self, filepath, weights_only: bool = False):
checkpoint = self.dump_checkpoint(weights_only)

if self.is_global_zero:
def _do_save(chkpt):
# do the actual save
try:
self._atomic_save(checkpoint, filepath)
self._atomic_save(chkpt, filepath)
except AttributeError as err:
if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
rank_zero_warn(
'Warning, `module_arguments` dropped from checkpoint.' f' An attribute is not picklable {err}'
)
self._atomic_save(checkpoint, filepath)
if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in chkpt:
del chkpt[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
rank_zero_warn('Warning, `module_arguments` dropped from checkpoint.'
f' An attribute is not picklable {err}')
self._atomic_save(chkpt, filepath)

checkpoint = self.dump_checkpoint(weights_only)

# self._atomic_save has different behavior for XLA vs
# non-XLA. In XLA, it has a barrier and internal logic to only
# save for rank==0, so need to call for all ranks. For non-XLA,
# it doesn't have rank==0 logic so only call for rank==0
if self.use_tpu:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

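Why not `or`? Both branches call `_do_save(checkpoint)`, so this could be a single `if self.use_tpu or self.is_global_zero:`.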

_do_save(checkpoint)
elif self.is_global_zero:
_do_save(checkpoint)

def restore(self, checkpoint_path: str, on_gpu: bool):
"""
Expand Down