[fix][FSDP] fix weight init when using apply() (fixes #490 and #444) #543
Conversation
Force-pushed from 46f422f to a31d5dd.
Doesn't work with SyncBN, will fix.
LGTM, nice tests!
Super nice. Some nonblocking comments.
The test failures might be related to the `cast_buffer` change?
It's very weird, tests only seem to fail on 1.6 but pass on 1.7.1 and 1.8 (and on my local machine)... will dig a bit.
Summary: Before this PR (facebookresearch/fairscale#543) was merged, we used to need the extra `cuda()` calls. Now, they are not needed. Unfortunately, this doesn't solve the long model init time issue we have. An FSDP model init still takes >20 mins for me, which is really bad for the regnet128 conv layer crash problem I am debugging. The following debugging output shows that most of the delay is in FSDP wrapping, some in BN wrapping, and some in the layer wrapping.

```
INFO 2021-04-14 12:18:35,883 regnet_2.py: 159: block created
INFO 2021-04-14 12:18:35,884 regnet_2.py: 161: cpu
INFO 2021-04-14 12:18:35,884 regnet_2.py: 161: cpu
INFO 2021-04-14 12:18:35,884 regnet_2.py: 161: cpu
INFO 2021-04-14 12:18:35,884 regnet_2.py: 161: cpu
INFO 2021-04-14 12:18:35,884 regnet_2.py: 161: cpu
INFO 2021-04-14 12:18:35,884 regnet_2.py: 161: cpu
INFO 2021-04-14 12:18:35,884 regnet_2.py: 161: cpu
INFO 2021-04-14 12:18:35,884 regnet_2.py: 161: cpu
INFO 2021-04-14 12:18:35,884 regnet_2.py: 161: cpu
INFO 2021-04-14 12:18:35,884 regnet_2.py: 161: cpu
INFO 2021-04-14 12:18:35,884 regnet_2.py: 161: cpu
INFO 2021-04-14 12:18:35,884 regnet_2.py: 161: cpu
INFO 2021-04-14 12:18:35,884 regnet_2.py: 161: cpu
INFO 2021-04-14 12:19:07,388 regnet_2.py: 163: block bn wrapped
INFO 2021-04-14 12:19:18,388 regnet_2.py: 166: block wrapped
```

In any case, this PR is pretty safe and should go in so that we don't need to do an extra `cuda()` call before wrapping.

Pull Request resolved: fairinternal/ssl_scaling#75

Reviewed By: prigoyal

Differential Revision: D27776285

Pulled By: min-xu-ai

fbshipit-source-id: 3e43c6fe750fd6ee35933400b03a069d62040d8a
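For illustration, a minimal sketch of the wrapping pattern this change enables. The block and its dimensions are made up, and it assumes a CUDA device is available and `torch.distributed.init_process_group` has already been called:

```python
import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel as FSDP

# Hypothetical block; shapes are arbitrary and only for illustration.
block = nn.Sequential(nn.Linear(128, 128), nn.BatchNorm1d(128))

# Previously an explicit move to GPU was needed before wrapping:
#     fsdp_block = FSDP(block.cuda())
# With this change, wrapping a module whose params are still on CPU
# works directly; the compute device is inferred at wrap time.
fsdp_block = FSDP(block)
```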
This PR makes two changes:

- Adds a `compute_device` so that we can use `summon_full_params` immediately after `FSDP.__init__`, even if the params are still on CPU (this also fixes "[FSDP] improve robustness to mismatch between torch.cuda.current_device and model's device" #444)
- Overrides `FSDP.apply` so that it calls `summon_full_params` first. This makes it possible to do weight inits via `model.apply(custom_weight_init_fn)` without segfaulting, and should give identical results to not using FSDP (see the sketch after this list).

This also required reworking `_all_buffers_to` to no longer use `apply`, since it is called from within `_lazy_init` and created some circular logic: `apply -> summon_full_params -> _lazy_init -> _all_buffers_to -> apply`.
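A minimal sketch of the weight-init pattern this enables, assuming an initialized process group and a CUDA device; `custom_weight_init_fn` is the illustrative name from the description above, not a library API:

```python
import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel as FSDP

def custom_weight_init_fn(module: nn.Module) -> None:
    # Hypothetical init fn; any standard torch.nn.init scheme works here.
    if isinstance(module, nn.Linear):
        nn.init.kaiming_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

model = FSDP(nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4)))

# FSDP.apply now gathers the full, unflattened params first (via
# summon_full_params), so the init fn sees the same tensors it would
# see on an unwrapped model instead of the flattened shards.
model.apply(custom_weight_init_fn)
```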