Fix loading zero3 weights (huggingface#36455)
* Check if fixes

* Fix zero3 loading

* Quality

* Fix marc nit

* Add fast tests

* Migrate to integrations.deepspeed rather than modeling_utils

* Style
muellerzr authored and garrett361 committed Mar 4, 2025
1 parent 4d57c2c commit 91ff42e
Showing 3 changed files with 105 additions and 3 deletions.
52 changes: 52 additions & 0 deletions src/transformers/integrations/deepspeed.py
@@ -27,6 +27,7 @@

if is_torch_available():
    import torch
    from torch import nn


logger = logging.get_logger(__name__)
@@ -305,6 +306,57 @@ def deepspeed_config():
        return None


def _load_state_dict_into_zero3_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
    """
    Loads state dict into a model specifically for Zero3, since DeepSpeed does not support the `transformers`
    tensor parallelism API.
    Nearly identical code to PyTorch's `_load_from_state_dict`
    """
    # copy state_dict so `_load_state_dict_into_zero3_model` can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
    def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        local_metadata["assign_to_params_buffers"] = assign_to_params_buffers

        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
        # Parameters of module and children will start with prefix. We can exit early if there are none in this
        # state_dict
        if is_deepspeed_zero3_enabled() and len([key for key in state_dict if key.startswith(prefix)]) > 0:
            import deepspeed

            # In sharded models, each shard has only part of the full state_dict, so only gather
            # parameters that are in the current state_dict.
            named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
            params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
            if len(params_to_gather) > 0:
                # because zero3 puts placeholders in model params, this context
                # manager gathers (unpartitions) the params of the current layer, then loads from
                # the state dict and then re-partitions them again
                with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
                    if torch.distributed.get_rank() == 0:
                        module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, state_dict, prefix + name + ".", assign_to_params_buffers)

    load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers)
    # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
    # it's safe to delete it.
    del state_dict

    return error_msgs


def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters):
    """
    A convenience wrapper that deals with optimizer and lr scheduler configuration.
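The helper added above walks the module tree and, for each submodule, gathers only the ZeRO-3 partitioned parameters whose state-dict keys start with that submodule's prefix before calling `_load_from_state_dict` on rank 0. Below is a minimal sketch of that prefix-based recursion in plain PyTorch (no DeepSpeed required); the `walk` function and the toy model are illustrative additions, not part of this commit.

from torch import nn

# Toy model standing in for model_to_load.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
state_dict = model.state_dict()

def walk(module: nn.Module, prefix: str = ""):
    # Parameters owned directly by this module, keyed with the full prefix, mirroring
    # `module.named_parameters(prefix=prefix[:-1], recurse=False)` in the helper above.
    named = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
    owned = [k for k in state_dict if k in named]
    if owned:
        print(f"{prefix or '<root>'} -> would gather {owned}")
    for name, child in module._modules.items():
        if child is not None:
            walk(child, prefix + name + ".")

walk(model)
# 0. -> would gather ['0.weight', '0.bias']
# 2. -> would gather ['2.weight', '2.bias']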
16 changes: 14 additions & 2 deletions src/transformers/modeling_utils.py
@@ -50,6 +50,7 @@
from .dynamic_module_utils import custom_object_save
from .generation import CompileConfig, GenerationConfig, GenerationMixin
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
from .integrations.deepspeed import _load_state_dict_into_zero3_model
from .integrations.flash_attention import flash_attention_forward
from .integrations.flex_attention import flex_attention_forward
from .integrations.sdpa_attention import sdpa_attention_forward
@@ -4918,7 +4919,13 @@ def _load_pretrained_model(
            mismatched_names = [name for name, _, _ in mismatched_keys]
            fixed_state_dict = {k: v for k, v in state_dict.items() if k not in mismatched_names}
            fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(fixed_state_dict)
            model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)

            if is_deepspeed_zero3_enabled():
                error_msgs += _load_state_dict_into_zero3_model(
                    model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
                )
            else:
                model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
        else:
            # This should always be a list but, just to be sure.
            if not isinstance(resolved_archive_file, list):
@@ -5009,7 +5016,12 @@ def _load_pretrained_model(
                        model_to_load, state_dict, start_prefix
                    )
                    fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict)
                    model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
                    if is_deepspeed_zero3_enabled():
                        error_msgs += _load_state_dict_into_zero3_model(
                            model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
                        )
                    else:
                        model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
                # force memory release
                del state_dict
                gc.collect()
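For context, the branches added above are only taken when a DeepSpeed ZeRO-3 config has been registered before the weights are loaded. A minimal sketch of that user-facing flow follows; it mirrors the new test rather than the diff itself, the checkpoint id is a placeholder, and it assumes `deepspeed` and `accelerate` are installed with a distributed environment available (e.g. launched via `deepspeed` or `torchrun`).

from transformers import AutoModel
from transformers.integrations.deepspeed import HfDeepSpeedConfig, is_deepspeed_zero3_enabled

# The HfDeepSpeedConfig must be created *before* from_pretrained and kept alive;
# constructing it registers the config globally.
ds_config = {
    "train_batch_size": 1,
    "zero_optimization": {"stage": 3},
}
dschf = HfDeepSpeedConfig(ds_config)
assert is_deepspeed_zero3_enabled()

# Placeholder checkpoint id. With ZeRO-3 enabled, _load_pretrained_model now routes the
# state dict through _load_state_dict_into_zero3_model instead of load_state_dict.
model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-t5")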
40 changes: 39 additions & 1 deletion tests/deepspeed/test_deepspeed.py
@@ -170,7 +170,6 @@ def parameterized_custom_name_func(func, param_num, param):


@require_deepspeed
@require_torch_accelerator
class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
"""
Testing non-Trainer DeepSpeed integration
@@ -194,13 +193,52 @@ def tearDown(self):
        # reset the ds config global so that tests state doesn't leak
        unset_hf_deepspeed_config()

    def test_init_zero3(self):
        # test that zero.Init() works correctly
        ds_config = {
            "train_batch_size": 1,
            "zero_optimization": {
                "stage": 3,
            },
        }

        dschf = HfDeepSpeedConfig(ds_config)

        self.assertTrue(dschf.is_zero3())
        self.assertTrue(is_deepspeed_zero3_enabled())

        with LoggingLevel(logging.INFO):
            with mockenv_context(**self.dist_env_1_gpu):
                logger = logging.get_logger("transformers.modeling_utils")
                with CaptureLogger(logger) as cl:
                    AutoModel.from_pretrained(T5_TINY)
        self.assertIn("Detected DeepSpeed ZeRO-3", cl.out)

        # now remove zero optimization
        del ds_config["zero_optimization"]
        dschf = HfDeepSpeedConfig(ds_config)

        self.assertFalse(dschf.is_zero3())
        self.assertFalse(is_deepspeed_zero3_enabled())

        with LoggingLevel(logging.INFO):
            with mockenv_context(**self.dist_env_1_gpu):
                logger = logging.get_logger("transformers.modeling_utils")
                with CaptureLogger(logger) as cl:
                    AutoModel.from_pretrained(T5_TINY)
        self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out)

    @require_torch_accelerator
    def test_init_zero3_fp16(self):
        # test that zero.Init() works correctly under zero3/fp16
        ds_config = {
            "train_batch_size": 1,
            "zero_optimization": {
                "stage": 3,
            },
            "fp16": {
                "enabled": True,
            },
        }

        dschf = HfDeepSpeedConfig(ds_config)
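The new `test_init_zero3` drops the accelerator requirement so it can run as a fast CPU test, and it relies on `HfDeepSpeedConfig` registering itself as a process-wide config on construction, which is also why `tearDown` calls `unset_hf_deepspeed_config()`. A short sketch of that lifecycle, assuming `transformers` is installed with the DeepSpeed extras:

from transformers.integrations.deepspeed import (
    HfDeepSpeedConfig,
    is_deepspeed_zero3_enabled,
    unset_hf_deepspeed_config,
)

ds_config = {"train_batch_size": 1, "zero_optimization": {"stage": 3}}
dschf = HfDeepSpeedConfig(ds_config)  # constructing it registers the global config
assert dschf.is_zero3() and is_deepspeed_zero3_enabled()

del ds_config["zero_optimization"]  # no ZeRO section -> no longer stage 3
dschf = HfDeepSpeedConfig(ds_config)
assert not dschf.is_zero3() and not is_deepspeed_zero3_enabled()

unset_hf_deepspeed_config()  # clean up the global, as the test's tearDown does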
