Skip to content

Commit

Permalink
address issue 2 in #83
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianlim committed Oct 14, 2024
1 parent a50ff63 commit 1f8cc16
Showing 1 changed file with 11 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,19 @@ def _all_reduce_hook(grad):

# because we will ignore these from FSDP, we need to manually
# move them to gpu if they are already not on them
# - if the adapters are on meta, we assume that this is for FSDP
# low_cpu_mem_mode purposes, and that the values will be synced over
# - So just initialize them to empty.
if not A.weight.is_cuda:
set_module_tensor_to_device(A, "weight", "cuda")
value = None
if A.weight.device == torch.device('meta'):
value = torch.empty(*A.weight.size(), dtype=A.weight.dtype)
set_module_tensor_to_device(A, "weight", "cuda", value)
if not B.weight.is_cuda:
set_module_tensor_to_device(B, "weight", "cuda")
value = None
if B.weight.device == torch.device('meta'):
value = torch.empty(*B.weight.size(), dtype=B.weight.dtype)
set_module_tensor_to_device(B, "weight", "cuda", value)


def register_foak_model_patch_rules(base_type):
Expand Down

0 comments on commit 1f8cc16

Please sign in to comment.