diff --git a/vissl/models/trunks/regnet_fsdp.py b/vissl/models/trunks/regnet_fsdp.py index 9fbbb8df9..6cd2e80b7 100644 --- a/vissl/models/trunks/regnet_fsdp.py +++ b/vissl/models/trunks/regnet_fsdp.py @@ -106,7 +106,7 @@ def __init__( bot_mul, group_width, params.se_ratio, - ).cuda() + ) # Init weight before wrapping and sharding. init_weights(block) @@ -127,7 +127,7 @@ class RegNetFSDP(FSDP): """ def __init__(self, model_config: AttrDict, model_name: str): - module = _RegNetFSDP(model_config, model_name).cuda() + module = _RegNetFSDP(model_config, model_name) super().__init__(module, **model_config.FSDP_CONFIG)