
Commit

update conversion script
ArthurZucker committed Feb 8, 2024
1 parent a5bb7a2 commit 2126d1c
Showing 2 changed files with 26 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/transformers/models/gemma/configuration_gemma.py
@@ -41,7 +41,7 @@ class GemmaConfig(PretrainedConfig):
    Args:
-        vocab_size (`int`, *optional*, defaults to 32000):
+        vocab_size (`int`, *optional*, defaults to 256000):
            Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`GemmaModel`]
        hidden_size (`int`, *optional*, defaults to 4096):
@@ -118,7 +118,7 @@ class GemmaConfig(PretrainedConfig):

    def __init__(
        self,
-        vocab_size=256128,
+        vocab_size=256000,
        hidden_size=3072,
        intermediate_size=24576,
        num_hidden_layers=28,
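With this change the docstring and the `__init__` default agree on a vocabulary size of 256000. As a quick sanity check, the new default can be inspected directly; a minimal sketch, assuming a transformers version that ships the Gemma classes:

from transformers import GemmaConfig

# With no arguments, the config picks up the defaults updated in this commit,
# including the corrected vocabulary size.
config = GemmaConfig()
print(config.vocab_size)   # expected: 256000
print(config.hidden_size)  # 3072, the 7B default shown in this diff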
29 changes: 24 additions & 5 deletions src/transformers/models/gemma/convert_gemma_weights_to_hf.py
@@ -17,7 +17,7 @@

import torch

-from transformers import FlaxGemmaForCausalLM, GemmaConfig, GemmaForCausalLM, GemmaTokenizer
+from transformers import GemmaConfig, GemmaForCausalLM, GemmaTokenizer


try:
@@ -61,7 +61,6 @@
gemma_7b_config = GemmaConfig()

CONFIG_MAPPING = {"2B": gemma_2b_config, "7B": gemma_7b_config}
-
LAYER_NAME_MAPPING = {"embedder.weight": "model.embed_tokens.weight"}


@@ -71,7 +70,7 @@ def write_model(save_path, input_base_path, config, safe_serialization=True):
    num_kv_heads = config.num_key_value_heads
    head_dim = config.head_dim

-    print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
+    print(f"Fetching all parameters from the checkpoint at '{input_base_path}'")
    model_state_dict = torch.load(input_base_path, map_location="cpu")["model_state_dict"]
    model_state_dict.pop("freqs_cis")

@@ -107,7 +106,27 @@ def write_model(save_path, input_base_path, config, safe_serialization=True):
    model.config.torch_dtype = torch.float32
    del model.config._name_or_path
    print("Saving in the Transformers format.")
-    model.save_pretrained(save_path, safe_serialization=safe_serialization)
+    push_to_hub = True
+    if push_to_hub:
+        print(f"pushing the model to {save_path}")
+        response = input("Please enter yes or no: ").lower().strip()
+        result = response == "yes"
+        if result:
+            model.push_to_hub(save_path, safe_serialization=safe_serialization, private=True)
+            print(f"Pushed float32")
+
+            fp16_model = model.to(torch.float16)
+            fp16_model.push_to_hub(save_path, safe_serialization=safe_serialization, revision="float16", private=True)
+            del fp16_model
+            print(f"Pushed float16")
+
+            bf16_model = model.to(torch.bfloat16)
+            bf16_model.push_to_hub(save_path, safe_serialization=safe_serialization, revision="bfloat16", private=True)
+            print(f"Pushed bfloat16")
+    else:
+        model.save_pretrained(save_path, safe_serialization=safe_serialization)




def write_tokenizer(input_tokenizer_path, save_path):
@@ -137,7 +156,7 @@ def main():
    )
    parser.add_argument(
        "--output_dir",
-        default="gemma_7b",
+        default="gg-hf/gemma-7b",
        help="Location to write HF model and tokenizer",
    )
    parser.add_argument(
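Because the script now pushes float16 and bfloat16 weights to dedicated `float16` and `bfloat16` revisions of the target repo, a converted checkpoint can later be loaded at a chosen precision through the `revision` argument of `from_pretrained`. A hedged sketch; the repo id below simply mirrors the new `--output_dir` default from this commit and is not guaranteed to be publicly accessible:

import torch
from transformers import GemmaForCausalLM, GemmaTokenizer

repo_id = "gg-hf/gemma-7b"  # illustrative only: the default output location set in this commit

# Pull the bfloat16 branch that the conversion script uploads.
model = GemmaForCausalLM.from_pretrained(repo_id, revision="bfloat16", torch_dtype=torch.bfloat16)
tokenizer = GemmaTokenizer.from_pretrained(repo_id)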
