Fix loading zero3 weights #36455

Merged
merged 9 commits on Mar 3, 2025
Changes from 3 commits
60 changes: 59 additions & 1 deletion src/transformers/modeling_utils.py
@@ -683,6 +683,57 @@ def _find_disjoint(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor])
     return shared_tensors, disjoint_tensors
 
 
+def _load_state_dict_into_zero3_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
+    """
+    Loads a state dict into a model specifically for ZeRO-3, since DeepSpeed does not support the
+    `transformers` tensor parallelism API.
+
+    Nearly identical code to PyTorch's `_load_from_state_dict`.
+    """
+    # copy state_dict so `_load_state_dict_into_zero3_model` can modify it
+    metadata = getattr(state_dict, "_metadata", None)
+    state_dict = state_dict.copy()
+    if metadata is not None:
+        state_dict._metadata = metadata
+
+    error_msgs = []
+
+    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants,
+    # so we need to apply the function recursively.
+    def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
+        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+        local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
+
+        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
+        # Parameters of module and children will start with prefix. We can exit early if there are
+        # none in this state_dict.
+        if is_deepspeed_zero3_enabled() and len([key for key in state_dict if key.startswith(prefix)]) > 0:
+            import deepspeed
+
+            # In sharded models, each shard has only part of the full state_dict, so only gather
+            # parameters that are in the current state_dict.
+            named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
+            params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
+            if len(params_to_gather) > 0:
+                # Because ZeRO-3 puts placeholders in model params, this context manager gathers
+                # (unpartitions) the params of the current layer, then loads from the state dict,
+                # and then re-partitions them again.
+                with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
+                    if torch.distributed.get_rank() == 0:
+                        module._load_from_state_dict(*args)
+
+        for name, child in module._modules.items():
+            if child is not None:
+                load(child, state_dict, prefix + name + ".", assign_to_params_buffers)
+
+    load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers)
+    # Delete `state_dict` so it can be collected by the GC earlier. Note that `state_dict` is a copy
+    # of the argument, so it's safe to delete it.
+    del state_dict
+
+    return error_msgs
+
+
 def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]:
     shared_tensors = []
     identical = []
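For context on the pattern the new helper relies on: ZeRO-3 keeps only a shard of each parameter on every rank and leaves placeholders on the module, so a state dict cannot simply be copied in. Below is a minimal standalone sketch, not part of this PR (the module and `module_sd` are illustrative), of the gather-then-load-then-re-partition idiom built on `deepspeed.zero.GatheredParameters`:

# Minimal sketch, not part of this PR: the gather -> load -> re-partition idiom
# that `_load_state_dict_into_zero3_model` applies module by module.
import deepspeed
import torch
import torch.nn as nn

def load_one_module(module: nn.Module, module_sd: dict) -> None:
    # Each rank holds only a shard of these parameters; inside the context
    # manager the full tensors are temporarily reassembled on every rank.
    params = list(module.parameters(recurse=False))
    with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
        # modifier_rank=0: only rank 0's writes are kept and broadcast to the
        # other ranks when the context exits and params are re-partitioned.
        if torch.distributed.get_rank() == 0:
            module.load_state_dict(module_sd, strict=False)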
@@ -4908,7 +4959,12 @@ def _load_pretrained_model(
             )
             # at this point the state dict should be on cpu, we don't need to actually read it
             fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict)
-            model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
+            if is_deepspeed_zero3_enabled():
+                error_msgs += _load_state_dict_into_zero3_model(
+                    model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
+                )
+            else:
+                model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
         else:
             # This should always be a list but, just to be sure.
             if not isinstance(resolved_archive_file, list):
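Why the vanilla call cannot stay on this branch: under ZeRO-3, `from_pretrained` builds the model inside `deepspeed.zero.Init()`, which leaves each parameter as a zero-size placeholder, so a plain `nn.Module.load_state_dict` has nothing real to copy into (and with `assign=True` it would replace the DeepSpeed-managed parameters outright). A hedged illustration of the placeholder behavior, assuming a distributed launcher such as `deepspeed this_script.py`:

# Sketch (assumes a distributed launcher): parameters created under zero.Init
# are placeholders, not full tensors.
import deepspeed
import torch.nn as nn

deepspeed.init_distributed()
with deepspeed.zero.Init():
    layer = nn.Linear(4, 4)

print(layer.weight.shape)     # torch.Size([0]) -- zero-size placeholder on this rank
print(layer.weight.ds_shape)  # (4, 4) -- the true shape tracked by DeepSpeed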
@@ -4991,6 +5047,8 @@ def _load_pretrained_model(
                         shard_file=shard_file,
                     )
                     error_msgs += new_error_msgs
+                elif is_deepspeed_zero3_enabled():
+                    fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
                 else:
                     state_dict = load_state_dict(shard_file, map_location="cpu", weights_only=weights_only)
                     # Sharded checkpoint or whole but low_cpu_mem_usage==True
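Taken together, the fix matters for the standard Trainer flow, where ZeRO-3 is enabled before the model is created. A usage sketch under stated assumptions (the config path and model name are illustrative):

# Usage sketch (illustrative names): with a ZeRO stage-3 config active,
# `is_deepspeed_zero3_enabled()` is True inside `from_pretrained`, the model is
# partitioned as it is built, and weight loading routes through the new helper.
from transformers import AutoModelForCausalLM, TrainingArguments

# Parsing a ZeRO-3 DeepSpeed config here flips the global zero3 flag that
# `from_pretrained` consults.
args = TrainingArguments(output_dir="out", deepspeed="ds_zero3_config.json")
model = AutoModelForCausalLM.from_pretrained("gpt2")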