From 05f3a79bcca3e121bce70f24774083348eb905da Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Thu, 27 Feb 2025 10:45:45 -0500
Subject: [PATCH 1/7] Check if fixes

---
 src/transformers/modeling_utils.py | 53 ++++++++++++++++++++++++++++++
 1 file changed, 53 insertions(+)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 1553287c92a0..6175c89c6e88 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -683,6 +683,57 @@ def _find_disjoint(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor])
     return shared_tensors, disjoint_tensors
 
 
+def _load_state_dict_into_zero3_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
+    """
+    Loads state dict into a model specifically for Zero3, since DeepSpeed does not support the `transformers`
+    tensor parallelism API.
+
+    Nearly identical code to PyTorch's `_load_from_state_dict`
+    """
+    # copy state_dict so `_load_state_dict_into_zero3_model` can modify it
+    metadata = getattr(state_dict, "_metadata", None)
+    state_dict = state_dict.copy()
+    if metadata is not None:
+        state_dict._metadata = metadata
+
+    error_msgs = []
+
+    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
+    # so we need to apply the function recursively.
+    def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
+        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+        local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
+
+        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
+        # Parameters of module and children will start with prefix. We can exit early if there are none in this
+        # state_dict
+        if is_deepspeed_zero3_enabled() and len([key for key in state_dict if key.startswith(prefix)]) > 0:
+            import deepspeed
+
+            # In sharded models, each shard has only part of the full state_dict, so only gather
+            # parameters that are in the current state_dict.
+            named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
+            params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
+            if len(params_to_gather) > 0:
+                # because zero3 puts placeholders in model params, this context
+                # manager gathers (unpartitions) the params of the current layer, then loads from
+                # the state dict and then re-partitions them again
+                with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
+                    if torch.distributed.get_rank() == 0:
+                        module._load_from_state_dict(*args)
+
+        for name, child in module._modules.items():
+            if child is not None:
+                load(child, state_dict, prefix + name + ".", assign_to_params_buffers)
+
+    load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers)
+    # Delete `state_dict` so it can be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
+    # it's safe to delete it.
+    del state_dict
+
+    return error_msgs
+
+
 def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]:
     shared_tensors = []
     identical = []
@@ -4991,6 +5042,8 @@ def _load_pretrained_model(
                         shard_file=shard_file,
                     )
                     error_msgs += new_error_msgs
+                elif is_deepspeed_zero3_enabled():
+                    fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
                 else:
                     state_dict = load_state_dict(shard_file, map_location="cpu", weights_only=weights_only)
                     # Sharded checkpoint or whole but low_cpu_mem_usage==True
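
Patch 1's helper hinges on the gather-then-load pattern around deepspeed.zero.GatheredParameters. As a reading aid, here is a minimal, self-contained sketch of that pattern; it assumes an already-initialized ZeRO-3 process group, and load_layer_under_zero3 and its arguments are illustrative names, not part of the patch:

    import deepspeed
    import torch
    from torch import nn


    def load_layer_under_zero3(module: nn.Module, layer_state: dict) -> None:
        # Direct (non-recursive) parameters only, mirroring the per-module
        # recursion in _load_state_dict_into_zero3_model above.
        params = list(module.parameters(recurse=False))
        if not params:
            return
        # Under ZeRO-3 every rank holds only a shard (placeholder) of each
        # parameter. GatheredParameters temporarily reassembles the full
        # tensors; with modifier_rank=0, only rank 0's in-place writes are
        # kept when the context exits and the parameters are re-partitioned.
        with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
            if torch.distributed.get_rank() == 0:
                module.load_state_dict(layer_state, strict=False)
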
From 25c5ef4bc4fd1c0031d97d9339b0861712855abe Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Thu, 27 Feb 2025 10:55:49 -0500
Subject: [PATCH 2/7] Fix zero3 loading

---
 src/transformers/modeling_utils.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 6175c89c6e88..740226bc6b47 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4959,7 +4959,10 @@ def _load_pretrained_model(
             )
             # at this point the state dict should be on cpu, we don't need to actually read it
             fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict)
-            model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
+            if is_deepspeed_zero3_enabled():
+                error_msgs += _load_state_dict_into_zero3_model(model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers)
+            else:
+                model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
         else:
             # This should always be a list but, just to be sure.
             if not isinstance(resolved_archive_file, list):

From 3cea82ea003d91ad7d648234e6a609ebb161c7ac Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Thu, 27 Feb 2025 11:02:46 -0500
Subject: [PATCH 3/7] Quality

---
 src/transformers/modeling_utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 740226bc6b47..538c51c91362 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4960,7 +4960,9 @@ def _load_pretrained_model(
             # at this point the state dict should be on cpu, we don't need to actually read it
             fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict)
             if is_deepspeed_zero3_enabled():
-                error_msgs += _load_state_dict_into_zero3_model(model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers)
+                error_msgs += _load_state_dict_into_zero3_model(
+                    model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
+                )
             else:
                 model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
         else:

From 2818e6c21469b9eff867b9493b35eecc11bf6122 Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Thu, 27 Feb 2025 11:15:35 -0500
Subject: [PATCH 4/7] Fix marc nit

---
 src/transformers/modeling_utils.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 538c51c91362..fa8206293b68 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -5047,8 +5047,6 @@ def _load_pretrained_model(
                         shard_file=shard_file,
                     )
                     error_msgs += new_error_msgs
-                elif is_deepspeed_zero3_enabled():
-                    fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
                 else:
                     state_dict = load_state_dict(shard_file, map_location="cpu", weights_only=weights_only)
                     # Sharded checkpoint or whole but low_cpu_mem_usage==True
@@ -5057,7 +5055,12 @@ def _load_pretrained_model(
                     model_to_load, state_dict, start_prefix
                 )
                 fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict)
-                model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
+                if is_deepspeed_zero3_enabled():
+                    error_msgs += _load_state_dict_into_zero3_model(
+                        model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
+                    )
+                else:
+                    model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
                 # force memory release
                 del state_dict
                 gc.collect()
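
Patches 2-4 converge on a single dispatch point in _load_pretrained_model: fix the checkpoint keys once, then route either to the ZeRO-3 helper or to a plain load_state_dict. A condensed, runnable sketch of that dispatch follows; dispatch_load and zero3_loader are illustrative stand-ins for the inlined logic and for _load_state_dict_into_zero3_model, while the branch bodies mirror the diff:

    from typing import Callable, Dict, List

    import torch
    from torch import nn


    def dispatch_load(
        model_to_load: nn.Module,
        fixed_state_dict: Dict[str, torch.Tensor],
        start_prefix: str,
        assign_to_params_buffers: bool,
        zero3_enabled: bool,
        zero3_loader: Callable[..., List[str]],
    ) -> List[str]:
        # Mirrors the branch added in patches 2-4: under ZeRO-3, parameters
        # are partitioned placeholders, so loading must go through the
        # gather-then-load helper rather than nn.Module.load_state_dict.
        error_msgs: List[str] = []
        if zero3_enabled:
            error_msgs += zero3_loader(model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers)
        else:
            model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
        return error_msgs
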
From b07ca4d6b71808127ef41f0bb46ac2c5c970f02c Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Thu, 27 Feb 2025 11:26:36 -0500
Subject: [PATCH 5/7] Add fast tests

---
 tests/deepspeed/test_deepspeed.py | 40 ++++++++++++++++++++++++++++++-
 1 file changed, 39 insertions(+), 1 deletion(-)

diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py
index 28ab70059091..80a926f08db5 100644
--- a/tests/deepspeed/test_deepspeed.py
+++ b/tests/deepspeed/test_deepspeed.py
@@ -170,7 +170,6 @@ def parameterized_custom_name_func(func, param_num, param):
 
 
 @require_deepspeed
-@require_torch_accelerator
 class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
     """
     Testing non-Trainer DeepSpeed integration
@@ -194,6 +193,42 @@ def tearDown(self):
         # reset the ds config global so that tests state doesn't leak
         unset_hf_deepspeed_config()
 
+    def test_init_zero3(self):
+        # test that zero.Init() works correctly
+        ds_config = {
+            "train_batch_size": 1,
+            "zero_optimization": {
+                "stage": 3,
+            },
+        }
+
+        dschf = HfDeepSpeedConfig(ds_config)
+
+        self.assertTrue(dschf.is_zero3())
+        self.assertTrue(is_deepspeed_zero3_enabled())
+
+        with LoggingLevel(logging.INFO):
+            with mockenv_context(**self.dist_env_1_gpu):
+                logger = logging.get_logger("transformers.modeling_utils")
+                with CaptureLogger(logger) as cl:
+                    AutoModel.from_pretrained(T5_TINY)
+        self.assertIn("Detected DeepSpeed ZeRO-3", cl.out)
+
+        # now remove zero optimization
+        del ds_config["zero_optimization"]
+        dschf = HfDeepSpeedConfig(ds_config)
+
+        self.assertFalse(dschf.is_zero3())
+        self.assertFalse(is_deepspeed_zero3_enabled())
+
+        with LoggingLevel(logging.INFO):
+            with mockenv_context(**self.dist_env_1_gpu):
+                logger = logging.get_logger("transformers.modeling_utils")
+                with CaptureLogger(logger) as cl:
+                    AutoModel.from_pretrained(T5_TINY)
+        self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out)
+
+    @require_torch_accelerator
     def test_init_zero3_fp16(self):
         # test that zero.Init() works correctly under zero3/fp16
         ds_config = {
@@ -201,6 +236,9 @@ def test_init_zero3_fp16(self):
             "zero_optimization": {
                 "stage": 3,
             },
+            "fp16": {
+                "enabled": True,
+            },
         }
 
         dschf = HfDeepSpeedConfig(ds_config)
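
The new fast test above exercises the same path a user would take: constructing HfDeepSpeedConfig with ZeRO stage 3 flips the global flag that from_pretrained consults. A minimal usage sketch under the same assumption (a working DeepSpeed install; "t5-small" stands in for any checkpoint):

    from transformers import AutoModel
    from transformers.integrations import HfDeepSpeedConfig, is_deepspeed_zero3_enabled

    ds_config = {"train_batch_size": 1, "zero_optimization": {"stage": 3}}
    dschf = HfDeepSpeedConfig(ds_config)  # keep this object alive while loading
    assert is_deepspeed_zero3_enabled()
    model = AutoModel.from_pretrained("t5-small")  # weights now load via the ZeRO-3 path
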
From c34ab8f4cdc4f9d2c3b78c1425c03b3472ae7723 Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Thu, 27 Feb 2025 11:29:06 -0500
Subject: [PATCH 6/7] Migrate to integrations.deepspeed rather than
 modeling_utils

---
 src/transformers/integrations/deepspeed.py | 52 ++++++++++++++++++++++
 src/transformers/modeling_utils.py         | 52 +---------------------
 2 files changed, 53 insertions(+), 51 deletions(-)

diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py
index e4742ecd1bfb..1b51a531645d 100644
--- a/src/transformers/integrations/deepspeed.py
+++ b/src/transformers/integrations/deepspeed.py
@@ -27,6 +27,7 @@
 
 if is_torch_available():
     import torch
+    from torch import nn
 
 
 logger = logging.get_logger(__name__)
@@ -305,6 +306,57 @@ def deepspeed_config():
         return None
 
 
+def _load_state_dict_into_zero3_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
+    """
+    Loads state dict into a model specifically for Zero3, since DeepSpeed does not support the `transformers`
+    tensor parallelism API.
+
+    Nearly identical code to PyTorch's `_load_from_state_dict`
+    """
+    # copy state_dict so `_load_state_dict_into_zero3_model` can modify it
+    metadata = getattr(state_dict, "_metadata", None)
+    state_dict = state_dict.copy()
+    if metadata is not None:
+        state_dict._metadata = metadata
+
+    error_msgs = []
+
+    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
+    # so we need to apply the function recursively.
+    def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
+        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+        local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
+
+        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
+        # Parameters of module and children will start with prefix. We can exit early if there are none in this
+        # state_dict
+        if is_deepspeed_zero3_enabled() and len([key for key in state_dict if key.startswith(prefix)]) > 0:
+            import deepspeed
+
+            # In sharded models, each shard has only part of the full state_dict, so only gather
+            # parameters that are in the current state_dict.
+            named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
+            params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
+            if len(params_to_gather) > 0:
+                # because zero3 puts placeholders in model params, this context
+                # manager gathers (unpartitions) the params of the current layer, then loads from
+                # the state dict and then re-partitions them again
+                with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
+                    if torch.distributed.get_rank() == 0:
+                        module._load_from_state_dict(*args)
+
+        for name, child in module._modules.items():
+            if child is not None:
+                load(child, state_dict, prefix + name + ".", assign_to_params_buffers)
+
+    load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers)
+    # Delete `state_dict` so it can be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
+    # it's safe to delete it.
+    del state_dict
+
+    return error_msgs
+
+
 def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters):
     """
     A convenience wrapper that deals with optimizer and lr scheduler configuration.
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index fa8206293b68..62125921eb11 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -51,6 +51,7 @@
 from .dynamic_module_utils import custom_object_save
 from .generation import CompileConfig, GenerationConfig, GenerationMixin
 from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
+from .integrations.deepspeed import _load_state_dict_into_zero3_model
 from .integrations.flash_attention import flash_attention_forward
 from .integrations.flex_attention import flex_attention_forward
 from .integrations.sdpa_attention import sdpa_attention_forward
@@ -683,57 +684,6 @@ def _find_disjoint(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor])
     return shared_tensors, disjoint_tensors
 
 
-def _load_state_dict_into_zero3_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
-    """
-    Loads state dict into a model specifically for Zero3, since DeepSpeed does not support the `transformers`
-    tensor parallelism API.
-
-    Nearly identical code to PyTorch's `_load_from_state_dict`
-    """
-    # copy state_dict so `_load_state_dict_into_zero3_model` can modify it
-    metadata = getattr(state_dict, "_metadata", None)
-    state_dict = state_dict.copy()
-    if metadata is not None:
-        state_dict._metadata = metadata
-
-    error_msgs = []
-
-    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
-    # so we need to apply the function recursively.
-    def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
-        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
-        local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
-
-        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
-        # Parameters of module and children will start with prefix. We can exit early if there are none in this
-        # state_dict
-        if is_deepspeed_zero3_enabled() and len([key for key in state_dict if key.startswith(prefix)]) > 0:
-            import deepspeed
-
-            # In sharded models, each shard has only part of the full state_dict, so only gather
-            # parameters that are in the current state_dict.
-            named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
-            params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
-            if len(params_to_gather) > 0:
-                # because zero3 puts placeholders in model params, this context
-                # manager gathers (unpartitions) the params of the current layer, then loads from
-                # the state dict and then re-partitions them again
-                with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
-                    if torch.distributed.get_rank() == 0:
-                        module._load_from_state_dict(*args)
-
-        for name, child in module._modules.items():
-            if child is not None:
-                load(child, state_dict, prefix + name + ".", assign_to_params_buffers)
-
-    load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers)
-    # Delete `state_dict` so it can be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
-    # it's safe to delete it.
-    del state_dict
-
-    return error_msgs
-
-
 def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]:
     shared_tensors = []
     identical = []
From 6984c71b801fdfaaf6629bb833209289416c1ae6 Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Fri, 28 Feb 2025 09:47:21 -0500
Subject: [PATCH 7/7] Style

---
 src/transformers/modeling_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 58dba7a5460b..5765e5f23d97 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4911,7 +4911,7 @@ def _load_pretrained_model(
                 mismatched_names = [name for name, _, _ in mismatched_keys]
                 fixed_state_dict = {k: v for k, v in state_dict.items() if k not in mismatched_names}
                 fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(fixed_state_dict)
-
+
                 if is_deepspeed_zero3_enabled():
                     error_msgs += _load_state_dict_into_zero3_model(
                         model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
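
After the series, loading under ZeRO-3 routes through _load_state_dict_into_zero3_model and surfaces any error_msgs exactly as the non-DeepSpeed path does. To confirm that weights actually landed in the partitions after loading, the same gather context the loader uses can be reused read-only; a hedged sketch, where peek is an illustrative helper and not part of the patch:

    import deepspeed

    def peek(param):
        # Outside the context a ZeRO-3 parameter is an empty placeholder;
        # inside, the full tensor is gathered onto every rank (read-only
        # here, since no modifier_rank is given).
        with deepspeed.zero.GatheredParameters([param]):
            print(tuple(param.shape), param.float().abs().mean().item())
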