From 6a0fcf0607bca59e2d0c69bc9e8bb3512c875b1f Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Mon, 3 Mar 2025 09:05:58 -0500
Subject: [PATCH] Fix loading zero3 weights (#36455)

* Check if fixes

* Fix zero3 loading

* Quality

* Fix marc nit

* Add fast tests

* Migrate to integrations.deepspeed rather than modeling_utils

* Style
---
 src/transformers/integrations/deepspeed.py | 52 ++++++++++++++++++++++
 src/transformers/modeling_utils.py         | 16 ++++++-
 tests/deepspeed/test_deepspeed.py          | 40 ++++++++++++++++-
 3 files changed, 105 insertions(+), 3 deletions(-)

diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py
index e4742ecd1bfb..1b51a531645d 100644
--- a/src/transformers/integrations/deepspeed.py
+++ b/src/transformers/integrations/deepspeed.py
@@ -27,6 +27,7 @@
 
 if is_torch_available():
     import torch
+    from torch import nn
 
 
 logger = logging.get_logger(__name__)
@@ -305,6 +306,57 @@ def deepspeed_config():
         return None
 
 
+def _load_state_dict_into_zero3_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
+    """
+    Loads state dict into a model specifically for Zero3, since DeepSpeed does not support the `transformers`
+    tensor parallelism API.
+
+    Nearly identical code to PyTorch's `_load_from_state_dict`
+    """
+    # copy state_dict so `_load_state_dict_into_zero3_model` can modify it
+    metadata = getattr(state_dict, "_metadata", None)
+    state_dict = state_dict.copy()
+    if metadata is not None:
+        state_dict._metadata = metadata
+
+    error_msgs = []
+
+    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
+    # so we need to apply the function recursively.
+    def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
+        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+        local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
+
+        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
+        # Parameters of module and children will start with prefix. We can exit early if there are none in this
+        # state_dict
+        if is_deepspeed_zero3_enabled() and len([key for key in state_dict if key.startswith(prefix)]) > 0:
+            import deepspeed
+
+            # In sharded models, each shard has only part of the full state_dict, so only gather
+            # parameters that are in the current state_dict.
+            named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
+            params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
+            if len(params_to_gather) > 0:
+                # because zero3 puts placeholders in model params, this context
+                # manager gathers (unpartitions) the params of the current layer, then loads from
+                # the state dict and then re-partitions them again
+                with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
+                    if torch.distributed.get_rank() == 0:
+                        module._load_from_state_dict(*args)
+
+        for name, child in module._modules.items():
+            if child is not None:
+                load(child, state_dict, prefix + name + ".", assign_to_params_buffers)
+
+    load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers)
+    # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
+    # it's safe to delete it.
+    del state_dict
+
+    return error_msgs
+
+
 def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters):
     """
     A convenience wrapper that deals with optimizer and lr scheduler configuration.
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 0554fa9d7d8c..a016f6013f1e 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -50,6 +50,7 @@
 from .dynamic_module_utils import custom_object_save
 from .generation import CompileConfig, GenerationConfig, GenerationMixin
 from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
+from .integrations.deepspeed import _load_state_dict_into_zero3_model
 from .integrations.flash_attention import flash_attention_forward
 from .integrations.flex_attention import flex_attention_forward
 from .integrations.sdpa_attention import sdpa_attention_forward
@@ -4918,7 +4919,13 @@ def _load_pretrained_model(
             mismatched_names = [name for name, _, _ in mismatched_keys]
             fixed_state_dict = {k: v for k, v in state_dict.items() if k not in mismatched_names}
             fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(fixed_state_dict)
-            model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
+
+            if is_deepspeed_zero3_enabled():
+                error_msgs += _load_state_dict_into_zero3_model(
+                    model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
+                )
+            else:
+                model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
         else:
             # This should always be a list but, just to be sure.
             if not isinstance(resolved_archive_file, list):
@@ -5009,7 +5016,12 @@ def _load_pretrained_model(
                         model_to_load, state_dict, start_prefix
                     )
                     fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict)
-                    model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
+                    if is_deepspeed_zero3_enabled():
+                        error_msgs += _load_state_dict_into_zero3_model(
+                            model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
+                        )
+                    else:
+                        model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
                 # force memory release
                 del state_dict
                 gc.collect()
diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py
index 28ab70059091..80a926f08db5 100644
--- a/tests/deepspeed/test_deepspeed.py
+++ b/tests/deepspeed/test_deepspeed.py
@@ -170,7 +170,6 @@ def parameterized_custom_name_func(func, param_num, param):
 
 
 @require_deepspeed
-@require_torch_accelerator
 class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
     """
     Testing non-Trainer DeepSpeed integration
@@ -194,6 +193,42 @@ def tearDown(self):
         # reset the ds config global so that tests state doesn't leak
         unset_hf_deepspeed_config()
 
+    def test_init_zero3(self):
+        # test that zero.Init() works correctly
+        ds_config = {
+            "train_batch_size": 1,
+            "zero_optimization": {
+                "stage": 3,
+            },
+        }
+
+        dschf = HfDeepSpeedConfig(ds_config)
+
+        self.assertTrue(dschf.is_zero3())
+        self.assertTrue(is_deepspeed_zero3_enabled())
+
+        with LoggingLevel(logging.INFO):
+            with mockenv_context(**self.dist_env_1_gpu):
+                logger = logging.get_logger("transformers.modeling_utils")
+                with CaptureLogger(logger) as cl:
+                    AutoModel.from_pretrained(T5_TINY)
+                self.assertIn("Detected DeepSpeed ZeRO-3", cl.out)
+
+        # now remove zero optimization
+        del ds_config["zero_optimization"]
+        dschf = HfDeepSpeedConfig(ds_config)
+
+        self.assertFalse(dschf.is_zero3())
+        self.assertFalse(is_deepspeed_zero3_enabled())
+
+        with LoggingLevel(logging.INFO):
+            with mockenv_context(**self.dist_env_1_gpu):
+                logger = logging.get_logger("transformers.modeling_utils")
+                with CaptureLogger(logger) as cl:
+                    AutoModel.from_pretrained(T5_TINY)
+                self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out)
+
+    @require_torch_accelerator
     def test_init_zero3_fp16(self):
         # test that zero.Init() works correctly under zero3/fp16
         ds_config = {
@@ -201,6 +236,9 @@ def test_init_zero3_fp16(self):
             "zero_optimization": {
                 "stage": 3,
             },
+            "fp16": {
+                "enabled": True,
+            },
         }
 
         dschf = HfDeepSpeedConfig(ds_config)
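
Usage sketch (not part of the patch): the new `test_init_zero3` exercises the fixed path roughly as below. This is a minimal sketch that assumes `deepspeed` is installed and that single-process distributed environment variables (MASTER_ADDR, MASTER_PORT, RANK, LOCAL_RANK, WORLD_SIZE) are already set, as the test arranges via `mockenv_context`; the tiny checkpoint name is a placeholder, not something the patch prescribes.

# Minimal sketch of the non-Trainer ZeRO-3 weight-loading path fixed here.
from transformers import AutoModel
from transformers.integrations import HfDeepSpeedConfig, is_deepspeed_zero3_enabled

ds_config = {
    "train_batch_size": 1,
    "zero_optimization": {"stage": 3},
}

# Keep a reference alive: from_pretrained() detects ZeRO-3 through this object,
# so is_deepspeed_zero3_enabled() must be True before the model is constructed.
dschf = HfDeepSpeedConfig(ds_config)
assert is_deepspeed_zero3_enabled()

# With ZeRO-3 enabled, from_pretrained() now routes the checkpoint through
# _load_state_dict_into_zero3_model(), which gathers the partitioned parameters
# before copying the checkpoint weights into them.
model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-t5")  # placeholder checkpoint

The nested if/else added in `_load_pretrained_model` keeps the non-DeepSpeed path on the existing `load_state_dict(..., assign=...)` route, while ZeRO-3 goes through the gather-then-load helper added in integrations/deepspeed.py.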