Move parameter validation specific to TPU Training plugins #7415
Conversation
Codecov Report
@@            Coverage Diff            @@
##           master   #7415      +/-  ##
=========================================
- Coverage      93%     88%       -5%
=========================================
  Files         200     200
  Lines       12962   12966        +4
=========================================
- Hits        11998   11377      -621
- Misses        964    1589      +625
LGTM!
@kaushikb11 the TPU tests are being skipped. Probably the TPU device is not being detected :)
# model = Model()
# trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)
def on_post_move_to_device(self):
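For context, this hook exists because moving a module to an XLA device copies each parameter into a new tensor, which silently breaks any weight tying done in __init__. A minimal sketch of how a user might re-tie weights in the hook (the WeightSharingModel name and layer sizes are made up for illustration):

import torch.nn as nn
import pytorch_lightning as pl


class WeightSharingModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(32, 10, bias=False)
        self.layer_2 = nn.Linear(10, 32, bias=False)
        self.layer_3 = nn.Linear(32, 10, bias=False)
        # Tie layer_3's weight to layer_1's weight (same shape).
        self.layer_3.weight = self.layer_1.weight

    def on_post_move_to_device(self):
        # On TPU, the transfer to the XLA device creates new parameter tensors,
        # so the tie from __init__ no longer holds; re-establish it here.
        self.layer_3.weight = self.layer_1.weight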
Could we make this check slightly smarter by checking parameter names?
If I tie self.layer_3.weight = self.layer_1.weight in the init function and then mess up and write self.layer_3.weight = self.layer_2.weight instead, I won't get a warning, but the tying is different. Ideally we would either explicitly tell the user which weights are shared or do it automatically for them (see the sketch after this thread).
Interesting, will follow up.
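To illustrate the suggestion above: a minimal sketch of detecting tied weights by parameter name, assuming a PyTorch version where named_parameters accepts remove_duplicate (the find_shared_parameters helper is hypothetical, not part of this PR):

from collections import defaultdict

import torch.nn as nn


def find_shared_parameters(module: nn.Module):
    # Group fully-qualified parameter names by the identity of the underlying
    # tensor; any group with more than one name corresponds to tied weights.
    groups = defaultdict(list)
    for name, param in module.named_parameters(remove_duplicate=False):
        groups[id(param)].append(name)
    return [names for names in groups.values() if len(names) > 1]

With the names in hand, the check could report exactly which weights were untied by the device transfer, or re-tie them automatically instead of relying on the user to implement on_post_move_to_device.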
@@ -171,6 +172,7 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
        if self.global_rank == 0:
            time.sleep(2)

    @parameter_validation
How slow is this?
@@ -71,12 +71,11 @@ def auto_transfer_args(self, *args, **kwargs):

def parameter_validation(fn: Callable) -> Callable:
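For reference, a minimal sketch of how a decorator like this might be structured after the change, assuming it wraps a method of an object that exposes self.model (the warning text here is illustrative, not the PR's exact wording):

from functools import wraps
from typing import Callable

from pytorch_lightning.utilities import rank_zero_warn


def parameter_validation(fn: Callable) -> Callable:
    @wraps(fn)
    def inner_fn(self, *args, **kwargs):
        # Parameters that share a tensor are counted once, so a count increase
        # after the wrapped call (e.g. moving the model to the TPU) means tied
        # weights were silently duplicated by the device transfer.
        pre_param_count = len(list(self.model.parameters()))
        result = fn(self, *args, **kwargs)
        post_param_count = len(list(self.model.parameters()))
        if pre_param_count != post_param_count:
            rank_zero_warn(
                "The model layers do not match after moving to the target device."
                " If your model employs weight sharing on TPU, please tie your"
                " weights in the `on_post_move_to_device` model hook."
            )
        return result

    return inner_fn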
I think that now that you changed the decorator target to self.model, this decorator may no longer fit well into core/decorators, since it is now effectively specific to plugins that have a self.model attribute.
What do you think about moving it? Just for consideration.
+1 to @awaelchli's suggestion
Good catch. Will do a follow-up PR for this.
Force-pushed from c6cc85f to c32227a
* Move parameter validation specific to TPU Training plugins
* update docstring
What does this PR do?
Follow-up to #5441. Moves the parameter validation check so that it is specific to the TPU training plugins.
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing, make sure you have read the review guidelines.
Did you have fun?
Make sure you had fun coding 🙃