Fix loading zero3 weights #36455

Merged: 9 commits, Mar 3, 2025
Changes from all commits
52 changes: 52 additions & 0 deletions src/transformers/integrations/deepspeed.py
@@ -27,6 +27,7 @@

if is_torch_available():
import torch
from torch import nn


logger = logging.get_logger(__name__)
@@ -305,6 +306,57 @@ def deepspeed_config():
return None


def _load_state_dict_into_zero3_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
"""
Loads state dict into a model specifically for Zero3, since DeepSpeed does not support the `transformers`
tensor parallelism API.

Nearly identical code to PyTorch's `_load_from_state_dict`
"""
# copy state_dict so `_load_state_dict_into_zero3_model` can modify it
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata

error_msgs = []

# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
local_metadata["assign_to_params_buffers"] = assign_to_params_buffers

args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
# Parameters of module and children will start with prefix. We can exit early if there are none in this
# state_dict
if is_deepspeed_zero3_enabled() and len([key for key in state_dict if key.startswith(prefix)]) > 0:
import deepspeed

# In sharded models, each shard has only part of the full state_dict, so only gather
# parameters that are in the current state_dict.
named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
if len(params_to_gather) > 0:
# because zero3 puts placeholders in model params, this context
# manager gathers (unpartitions) the params of the current layer, then loads from
# the state dict and then re-partitions them again
with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
if torch.distributed.get_rank() == 0:
module._load_from_state_dict(*args)

for name, child in module._modules.items():
if child is not None:
load(child, state_dict, prefix + name + ".", assign_to_params_buffers)

load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers)
# Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
# it's safe to delete it.
del state_dict

return error_msgs


def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters):
"""
A convenience wrapper that deals with optimizer and lr scheduler configuration.
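For readers unfamiliar with the new helper, here is a minimal sketch of calling it directly. This is not part of the PR: it assumes a ZeRO-3 `HfDeepSpeedConfig` has already been registered (so `is_deepspeed_zero3_enabled()` returns `True`), a single-process distributed environment is initialized, and the checkpoint path and model name are purely illustrative.

```python
# Sketch only (not part of this PR). Assumes: a ZeRO-3 HfDeepSpeedConfig is
# already registered, torch.distributed can initialize (e.g. RANK=0,
# WORLD_SIZE=1, MASTER_ADDR/MASTER_PORT set), and the paths are illustrative.
import torch
from transformers import AutoModel
from transformers.integrations import HfDeepSpeedConfig
from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model

ds_config = {"train_batch_size": 1, "zero_optimization": {"stage": 3}}
dschf = HfDeepSpeedConfig(ds_config)  # keep a reference; registering it enables ZeRO-3

# Under ZeRO-3, this model's parameters are placeholders partitioned across ranks.
model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-t5")

state_dict = torch.load("extra_weights.bin", map_location="cpu")  # hypothetical file

# Gathers each submodule's partitioned parameters on rank 0, copies the matching
# entries from `state_dict`, then re-partitions them; returns any error messages.
error_msgs = _load_state_dict_into_zero3_model(
    model, state_dict, start_prefix="", assign_to_params_buffers=False
)
print(error_msgs)
```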
16 changes: 14 additions & 2 deletions src/transformers/modeling_utils.py
@@ -50,6 +50,7 @@
from .dynamic_module_utils import custom_object_save
from .generation import CompileConfig, GenerationConfig, GenerationMixin
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
from .integrations.deepspeed import _load_state_dict_into_zero3_model
from .integrations.flash_attention import flash_attention_forward
from .integrations.flex_attention import flex_attention_forward
from .integrations.sdpa_attention import sdpa_attention_forward
@@ -4918,7 +4919,13 @@ def _load_pretrained_model(
mismatched_names = [name for name, _, _ in mismatched_keys]
fixed_state_dict = {k: v for k, v in state_dict.items() if k not in mismatched_names}
fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(fixed_state_dict)
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)

if is_deepspeed_zero3_enabled():
error_msgs += _load_state_dict_into_zero3_model(
model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
)
else:
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
else:
# This should always be a list but, just to be sure.
if not isinstance(resolved_archive_file, list):
@@ -5009,7 +5016,12 @@ def _load_pretrained_model(
model_to_load, state_dict, start_prefix
)
fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict)
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
if is_deepspeed_zero3_enabled():
error_msgs += _load_state_dict_into_zero3_model(
model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
)
else:
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
# force memory release
del state_dict
gc.collect()
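The same `if is_deepspeed_zero3_enabled()` branch appears at both call sites in `_load_pretrained_model`. A condensed sketch of the shared pattern follows; the wrapper function name is hypothetical and only meant to make the dispatch easier to read.

```python
# Condensed sketch of the dispatch added in modeling_utils.py; the wrapper
# name `_load_fixed_state_dict` is hypothetical, not part of the PR.
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model


def _load_fixed_state_dict(model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers):
    if is_deepspeed_zero3_enabled():
        # ZeRO-3: gather each layer's partitioned params on rank 0, load, re-partition.
        return _load_state_dict_into_zero3_model(
            model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
        )
    # Plain PyTorch path: strict=False so missing/unexpected keys do not raise;
    # assign=... optionally reuses checkpoint tensors instead of copying into params.
    model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
    return []
```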
40 changes: 39 additions & 1 deletion tests/deepspeed/test_deepspeed.py
@@ -170,7 +170,6 @@ def parameterized_custom_name_func(func, param_num, param):


@require_deepspeed
@require_torch_accelerator
class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
"""
Testing non-Trainer DeepSpeed integration
@@ -194,13 +193,52 @@ def tearDown(self):
# reset the ds config global so that tests state doesn't leak
unset_hf_deepspeed_config()

def test_init_zero3(self):
# test that zero.Init() works correctly
ds_config = {
"train_batch_size": 1,
"zero_optimization": {
"stage": 3,
},
}

dschf = HfDeepSpeedConfig(ds_config)

self.assertTrue(dschf.is_zero3())
self.assertTrue(is_deepspeed_zero3_enabled())

with LoggingLevel(logging.INFO):
with mockenv_context(**self.dist_env_1_gpu):
logger = logging.get_logger("transformers.modeling_utils")
with CaptureLogger(logger) as cl:
AutoModel.from_pretrained(T5_TINY)
self.assertIn("Detected DeepSpeed ZeRO-3", cl.out)

# now remove zero optimization
del ds_config["zero_optimization"]
dschf = HfDeepSpeedConfig(ds_config)

self.assertFalse(dschf.is_zero3())
self.assertFalse(is_deepspeed_zero3_enabled())

with LoggingLevel(logging.INFO):
with mockenv_context(**self.dist_env_1_gpu):
logger = logging.get_logger("transformers.modeling_utils")
with CaptureLogger(logger) as cl:
AutoModel.from_pretrained(T5_TINY)
self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out)

@require_torch_accelerator
def test_init_zero3_fp16(self):
# test that zero.Init() works correctly under zero3/fp16
ds_config = {
"train_batch_size": 1,
"zero_optimization": {
"stage": 3,
},
"fp16": {
"enabled": True,
},
}

dschf = HfDeepSpeedConfig(ds_config)
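Outside the test harness, the behaviour `test_init_zero3` asserts can be reproduced with a short stand-alone script. The sketch below is hedged: it assumes a single-process distributed environment and substitutes an illustrative tiny checkpoint for `T5_TINY`.

```python
# Stand-alone sketch of what test_init_zero3 checks, without the test helpers
# (LoggingLevel, mockenv_context, CaptureLogger). Assumes RANK/WORLD_SIZE/
# MASTER_ADDR/MASTER_PORT are set for a single process; model name is illustrative.
from transformers import AutoModel
from transformers.integrations import HfDeepSpeedConfig, is_deepspeed_zero3_enabled
from transformers.utils import logging

logging.set_verbosity_info()

ds_config = {"train_batch_size": 1, "zero_optimization": {"stage": 3}}
dschf = HfDeepSpeedConfig(ds_config)  # registering the config turns on the ZeRO-3 path
assert is_deepspeed_zero3_enabled()

# With ZeRO-3 active, from_pretrained logs "Detected DeepSpeed ZeRO-3" and
# initializes the model under deepspeed.zero.Init with partitioned parameters.
model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-t5")
```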