
Commit

Revert "Resize embeds with DeepSpeed (huggingface#32214)"
This reverts commit db8544a.
amathews-amd authored Aug 5, 2024
1 parent ca3b072 commit e109aee
Showing 1 changed file with 4 additions and 14 deletions.
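For context, `resize_token_embeddings` is the public entry point touched by both hunks below. A typical call looks like this (a hedged usage sketch; the checkpoint name and added token are arbitrary and not from this commit):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Adding tokens grows the vocabulary, so the embedding matrix
    # must be resized to match the new tokenizer length.
    tokenizer.add_tokens(["<my_new_token>"])
    model.resize_token_embeddings(len(tokenizer))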
18 changes: 4 additions & 14 deletions src/transformers/modeling_utils.py
@@ -1980,22 +1980,12 @@ def resize_token_embeddings(
         if new_num_tokens is None and pad_to_multiple_of is None:
             return model_embeds
 
-        # Since we are basically reusing the same old embeddings with new weight values, gathering is required
-        is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
-        if is_deepspeed_zero3_enabled() and not is_quantized:
-            import deepspeed
-
-            with deepspeed.zero.GatheredParameters(model_embeds.weight, modifier_rank=None):
-                vocab_size = model_embeds.weight.shape[0]
-        else:
-            vocab_size = model_embeds.weight.shape[0]
-
         # Update base model and current model config
         if hasattr(self.config, "text_config"):
-            self.config.text_config.vocab_size = vocab_size
+            self.config.text_config.vocab_size = model_embeds.weight.shape[0]
         else:
-            self.config.vocab_size = vocab_size
-        self.vocab_size = vocab_size
+            self.config.vocab_size = model_embeds.weight.shape[0]
+        self.vocab_size = model_embeds.weight.shape[0]
 
         # Tie weights again if needed
         self.tie_weights()
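The block removed above is DeepSpeed's ZeRO-3 gathering pattern: under ZeRO-3 each rank holds only a shard of every parameter, so `model_embeds.weight.shape` is not guaranteed to reflect the full vocabulary size until the weight is gathered. A minimal sketch of that pattern (the helper name is ours, not from the diff; it assumes a ZeRO-3 run with DeepSpeed installed):

    import deepspeed

    def read_full_vocab_size(embedding_weight):
        # modifier_rank=None gathers the parameter on all ranks for
        # read-only access; no rank may modify it inside the context.
        with deepspeed.zero.GatheredParameters(embedding_weight, modifier_rank=None):
            # Inside the context the full, un-partitioned tensor is visible.
            return embedding_weight.shape[0]

After the revert, the config is updated straight from `model_embeds.weight.shape[0]` with no gathering step.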
@@ -2149,7 +2139,7 @@ def _get_resized_embeddings(
 
             params = [old_embeddings.weight, new_embeddings.weight]
             with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
-                old_embeddings.weight = new_embeddings.weight
+                old_embeddings.weight.data = new_embeddings.weight.data
                 old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0]
 
                 # If the new number of tokens is smaller than the original `padding_idx`, the `padding_idx`
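This hunk swaps a parameter rebinding for an in-place `.data` copy. The difference: assigning `new_embeddings.weight` to `old_embeddings.weight` replaces the registered `nn.Parameter` object on the module, while assigning to `weight.data` swaps only the underlying tensor, so the original `Parameter` (and any outside references to it) stays alive. A standalone sketch of the two behaviors (toy sizes, plain PyTorch, no DeepSpeed):

    import torch.nn as nn

    old = nn.Embedding(4, 2)
    new = nn.Embedding(6, 2)
    param_before = old.weight           # keep a handle on the original Parameter

    old.weight.data = new.weight.data   # the form restored by this revert
    assert old.weight is param_before   # same Parameter object, new storage

    old.weight = new.weight                # the form being reverted
    assert old.weight is not param_before  # a different Parameter object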
