Tpu save #4309
Conversation
Codecov Report
@@            Coverage Diff            @@
##           master   #4309    +/-   ##
========================================
- Coverage      93%     93%     -0%
========================================
  Files         124     124
  Lines        9320    9303     -17
========================================
- Hits         8640    8616     -24
- Misses        680     687      +7
@lezwon does that move to CPU really belong in this file?
Why is this needed for TPUs?
You realize this will affect more than just TPUs, no? This is why I don't like making these changes outside the accelerator.
If it turns out this is the correct place for this, please add an on_save() function to each accelerator. Then make all of them a no-op:
def on_save(self, *args, **kwargs):
    # default: no accelerator-specific behaviour
    pass
And add the change ONLY to the TPU accelerator.
def on_save(self, *args, **kwargs):
    # TPU-specific save handling goes here
    your_changes()
Going forward, any changes that are accelerator-specific should not be made inside methods like this. Instead, each accelerator needs to implement the method, which is then called like:
self.accelerator.on_save()
The reason is that we are trying to break up all the underlying accelerator code so they are easier to debug and changes to an accelerator won't break all the others.
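A minimal sketch of the hook pattern being suggested (the class layout, the on_save signature, and the _to_cpu helper are illustrative, not the actual Lightning source):

import torch


def _to_cpu(obj):
    # Recursively move any tensors in a (possibly nested) checkpoint to CPU.
    if isinstance(obj, torch.Tensor):
        return obj.cpu()
    if isinstance(obj, dict):
        return {k: _to_cpu(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(_to_cpu(v) for v in obj)
    return obj


class Accelerator:
    def on_save(self, checkpoint):
        # Default: no accelerator-specific handling, return the checkpoint untouched.
        return checkpoint


class TPUAccelerator(Accelerator):
    def on_save(self, checkpoint):
        # TPU only: move everything to CPU so the checkpoint loads on CPU/GPU machines.
        return _to_cpu(checkpoint)

The checkpoint-saving code would then call checkpoint = self.accelerator.on_save(checkpoint) right before writing to disk, so only the TPU path pays the cost of the device-to-host copy.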
The XLA guide recommends moving tensors to CPU before saving so that they can be loaded on non-TPU devices. Lightning users have hit this issue: they train a model on a TPU and then aren't able to use it on a GPU or CPU. Hence we need to move these tensors to CPU before saving. I get your point about separating such code and making it accelerator-specific. Will refactor this into accelerator-specific code.
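For reference, a small illustration of what the XLA guide describes, assuming torch_xla is available; xm.save copies XLA tensors to CPU before serializing, so the resulting file loads with plain torch.load on CPU or GPU machines:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(4, 2).to(device)

# Option A: xm.save moves XLA tensors to CPU before writing to disk.
xm.save(model.state_dict(), "model.pt")

# Option B: move the state dict to CPU explicitly, then use the regular torch.save.
cpu_state = {k: v.cpu() for k, v in model.state_dict().items()}
torch.save(cpu_state, "model_cpu.pt")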
This pull request is now in conflict... :(
Hey @lezwon, any activity here? Were you able to reproduce this using the boring model, btw? Don't mind picking this up.
Hey @SeanNaren, I've been making some refactors to make this change TPU-specific. Will push an updated branch soon :)
This reverts commit 0c9316b
removed barrier for tpu during test; reduced epochs
lgtm
from pytorch_lightning.core.grads import GradInformation
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
where did this go? cc @tchaton
I guess useless imports.
It shouldn't be there at all in the first place, @SeanNaren
What does this PR do?
Fixes #2700
fixes #2303
fixes #3660
Removes accelerator.barrier() on TPU when calling .test(), as multiprocessing begins only in .fit() (see the sketch below).
Related to #3044
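A hypothetical sketch of why the barrier is only needed under multiprocessing (the helper name and guard are illustrative, not the PR's exact code):

import torch_xla.core.xla_model as xm


def tpu_barrier(tag="barrier"):
    # .fit() spawns multiple TPU processes, which must rendezvous here;
    # .test() runs in a single process, so a barrier there is unnecessary.
    if xm.xrt_world_size() > 1:
        xm.rendezvous(tag)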
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃