From e109aee9108e87e5d9364e1522f4546cbda869f0 Mon Sep 17 00:00:00 2001
From: Aswin John Mathews <81309834+amathews-amd@users.noreply.github.com>
Date: Mon, 5 Aug 2024 14:32:25 -0500
Subject: [PATCH] Revert "Resize embeds with DeepSpeed (#32214)"

This reverts commit db8544a7323b5ef0e6d7f9bd01a61b73ad8c1e16.
---
 src/transformers/modeling_utils.py | 18 ++++--------------
 1 file changed, 4 insertions(+), 14 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 557624f78b66..8f1ad56f6999 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1980,22 +1980,12 @@ def resize_token_embeddings(
         if new_num_tokens is None and pad_to_multiple_of is None:
             return model_embeds
 
-        # Since we are basically resuing the same old embeddings with new weight values, gathering is required
-        is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
-        if is_deepspeed_zero3_enabled() and not is_quantized:
-            import deepspeed
-
-            with deepspeed.zero.GatheredParameters(model_embeds.weight, modifier_rank=None):
-                vocab_size = model_embeds.weight.shape[0]
-        else:
-            vocab_size = model_embeds.weight.shape[0]
-
         # Update base model and current model config
         if hasattr(self.config, "text_config"):
-            self.config.text_config.vocab_size = vocab_size
+            self.config.text_config.vocab_size = model_embeds.weight.shape[0]
         else:
-            self.config.vocab_size = vocab_size
-        self.vocab_size = vocab_size
+            self.config.vocab_size = model_embeds.weight.shape[0]
+        self.vocab_size = model_embeds.weight.shape[0]
 
         # Tie weights again if needed
         self.tie_weights()
@@ -2149,7 +2139,7 @@ def _get_resized_embeddings(
 
             params = [old_embeddings.weight, new_embeddings.weight]
             with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
-                old_embeddings.weight = new_embeddings.weight
+                old_embeddings.weight.data = new_embeddings.weight.data
                 old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0]
 
         # If the new number of tokens is smaller than the original `padding_idx`, the `padding_idx`
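
For reference, the code removed by this revert guarded the vocab-size read
because, under DeepSpeed ZeRO-3, parameters are partitioned across ranks and
`weight.shape[0]` can report 0 until the parameter is gathered. Below is a
minimal sketch of that gathering pattern, not the exact transformers
implementation: it assumes a standard DeepSpeed install, imports
`is_deepspeed_zero3_enabled` from `transformers.integrations`, and the
function name `read_vocab_size` is illustrative only.

from transformers.integrations import is_deepspeed_zero3_enabled


def read_vocab_size(embedding):
    """Return the real row count of an embedding, even when ZeRO-3 has partitioned it."""
    if is_deepspeed_zero3_enabled():
        import deepspeed

        # modifier_rank=None gathers the parameter for read-only access;
        # no rank modifies the weights, so nothing is written back.
        with deepspeed.zero.GatheredParameters(embedding.weight, modifier_rank=None):
            return embedding.weight.shape[0]
    # Without ZeRO-3 the full weight tensor is local, so the shape is correct as-is.
    return embedding.weight.shape[0]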