-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] fix checkpointing on TPU #2726
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -292,12 +292,12 @@ def on_train_start(self, trainer, pl_module): | |
if not gfile.exists(self.dirpath): | ||
makedirs(self.dirpath) | ||
|
||
@rank_zero_only | ||
def on_validation_end(self, trainer, pl_module): | ||
# only run on main process | ||
if trainer.global_rank != 0: | ||
return | ||
|
||
# To get checkpointing working on TPU, need to call _save_model | ||
# for all ranks, to avoid deadlocks. Assuming save_function is mapped | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: without |
||
# to trainer.save_checkpoint, this will also work on GPU as save_checkpoint | ||
# handles rank==0 vs rank!=0 logic. If the user provides a custom | ||
# save_function, they are responsible for adding rank==0 vs rank!=0 logic. | ||
metrics = trainer.callback_metrics | ||
epoch = trainer.current_epoch | ||
|
||
|
@@ -348,8 +348,6 @@ def on_validation_end(self, trainer, pl_module): | |
else: | ||
if self.verbose > 0: | ||
log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}') | ||
|
||
assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0' | ||
self._save_model(filepath, trainer, pl_module) | ||
|
||
if self.save_last: | ||
|
@@ -384,6 +382,7 @@ def _do_check_save(self, filepath, current, epoch, trainer, pl_module): | |
f' {filepath} as top {self.save_top_k}') | ||
self._save_model(filepath, trainer, pl_module) | ||
|
||
for cur_path in del_list: | ||
if cur_path != filepath: | ||
self._del_model(cur_path) | ||
if trainer.global_rank == 0: | ||
for cur_path in del_list: | ||
if cur_path != filepath: | ||
self._del_model(cur_path) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -270,29 +270,41 @@ def _atomic_save(self, checkpoint, filepath: str): | |
This points to the file that the checkpoint will be stored in. | ||
""" | ||
tmp_path = str(filepath) + ".part" | ||
if self.use_tpu: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There is |
||
xm.save(checkpoint, tmp_path, master_only=True, global_master=True) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe add a barrier before There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. the barrier is already inside There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Excuse my naivety, but shouldn't there be one before save, to make sure that the weights have been updated by every process? |
||
if xm.is_master_ordinal(local=False): | ||
os.replace(tmp_path, filepath) | ||
# Can't use the new zipfile serialization for 1.6.0 because there's a bug in | ||
# torch.hub.load_state_dict_from_url() that prevents it from loading the new files. | ||
# More details can be found here: https://github.com/pytorch/pytorch/issues/42239 | ||
if LooseVersion(torch.__version__).version[:3] == [1, 6, 0]: | ||
elif LooseVersion(torch.__version__).version[:3] == [1, 6, 0]: | ||
torch.save(checkpoint, tmp_path, _use_new_zipfile_serialization=False) | ||
else: | ||
torch.save(checkpoint, tmp_path) | ||
os.replace(tmp_path, filepath) | ||
|
||
def save_checkpoint(self, filepath, weights_only: bool = False): | ||
checkpoint = self.dump_checkpoint(weights_only) | ||
|
||
if self.is_global_zero: | ||
def _do_save(chkpt): | ||
# do the actual save | ||
try: | ||
self._atomic_save(checkpoint, filepath) | ||
self._atomic_save(chkpt, filepath) | ||
except AttributeError as err: | ||
if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: | ||
del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] | ||
rank_zero_warn( | ||
'Warning, `module_arguments` dropped from checkpoint.' f' An attribute is not picklable {err}' | ||
) | ||
self._atomic_save(checkpoint, filepath) | ||
if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in chkpt: | ||
del chkpt[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] | ||
rank_zero_warn('Warning, `module_arguments` dropped from checkpoint.' | ||
f' An attribute is not picklable {err}') | ||
self._atomic_save(chkpt, filepath) | ||
|
||
checkpoint = self.dump_checkpoint(weights_only) | ||
|
||
# self._atomic_save has different behavior for XLA vs | ||
# non-XLA. In XLA, it has a barrier and internal logic to only | ||
# save for rank==0, so need to call for all ranks. For non-XLA, | ||
# it doesn't have rank==0 logic so only call for rank==0 | ||
if self.use_tpu: | ||
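There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why not `or`? Both branches call `_do_save(checkpoint)`, so this could be `if self.use_tpu or self.is_global_zero:`. |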
||
_do_save(checkpoint) | ||
elif self.is_global_zero: | ||
_do_save(checkpoint) | ||
|
||
def restore(self, checkpoint_path: str, on_gpu: bool): | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This may break DDP