Add resuming from specific checkpoint (Lightning-AI#516)
* Add resume_from_checkpoint

* Fix variable name

* Lightning-AI#515 Remove did_restore

* Lightning-AI#515 Simplify code

* Lightning-AI#515 Update doc for resume_from_checkpoint

* Lightning-AI#515 Add on_gpu
dreamgonfly authored and williamFalcon committed Nov 30, 2019
1 parent df7b6d9 commit 2b8475f
Showing 2 changed files with 8 additions and 3 deletions.
pytorch_lightning/trainer/trainer.py (3 additions, 1 deletion)
```diff
@@ -83,7 +83,8 @@ def __init__(self,
                  weights_save_path=None,
                  amp_level='O1',
                  nb_sanity_val_steps=5,
-                 truncated_bptt_steps=None):
+                 truncated_bptt_steps=None,
+                 resume_from_checkpoint=None):
         """
         :param logger: Logger for experiment tracking
@@ -140,6 +141,7 @@ def __init__(self,
         self.nb_sanity_val_steps = nb_sanity_val_steps
         self.print_nan_grads = print_nan_grads
         self.truncated_bptt_steps = truncated_bptt_steps
+        self.resume_from_checkpoint = resume_from_checkpoint
         self.shown_warnings = set()
 
         self.fast_dev_run = fast_dev_run
```
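For context, a minimal usage sketch of the new argument. `MyModel` and the checkpoint path are hypothetical placeholders; only `resume_from_checkpoint` is the argument added by this commit:

```python
from pytorch_lightning import Trainer

# Sketch only: `MyModel` stands in for any user-defined LightningModule
# subclass, and the checkpoint path is a hypothetical example.
model = MyModel()

# Pass an explicit checkpoint path instead of relying on automatic
# restoration from the experiment version directory.
trainer = Trainer(resume_from_checkpoint='checkpoints/_ckpt_epoch_4.ckpt')
trainer.fit(model)
```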
pytorch_lightning/trainer/trainer_io.py (5 additions, 2 deletions)
```diff
@@ -138,8 +138,11 @@ def restore_weights(self, model):
             torch.cuda.empty_cache()
 
         if not did_restore_hpc_weights:
-            # restore weights if same exp version
-            self.restore_state_if_checkpoint_exists(model)
+            if self.resume_from_checkpoint is not None:
+                self.restore(self.resume_from_checkpoint, on_gpu=self.on_gpu)
+            else:
+                # restore weights if same exp version
+                self.restore_state_if_checkpoint_exists(model)
 
         # wait for all models to restore weights
         if self.use_ddp or self.use_ddp2:
```
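With this change, `restore_weights` applies a clear precedence: an HPC checkpoint restore (if one happened) wins, then an explicitly supplied `resume_from_checkpoint`, and only as a fallback the automatic restore from the same experiment version. For illustration, a rough sketch of what a restore along these lines involves; this is not the actual `Trainer.restore` from trainer_io.py, and the checkpoint key names and signature are assumptions:

```python
import torch

def restore(checkpoint_path, model, optimizer, on_gpu):
    """Rough sketch of restoring weights and training state from a checkpoint.

    Assumptions: checkpoint keys 'state_dict', 'optimizer_states', and
    'epoch', and a single optimizer. The real implementation differs.
    """
    map_location = None if on_gpu else 'cpu'
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    # Restore model weights.
    model.load_state_dict(checkpoint['state_dict'])

    # Restore optimizer state so training resumes where it left off.
    optimizer.load_state_dict(checkpoint['optimizer_states'][0])

    # Resume epoch counting from the saved position.
    return checkpoint.get('epoch', 0)
```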
