From 2126d1cae6e6b26456d1f0322d2db94f7eeac426 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Thu, 8 Feb 2024 19:20:21 +0900
Subject: [PATCH] update conversion script

---
 .../models/gemma/configuration_gemma.py      |  4 +--
 .../gemma/convert_gemma_weights_to_hf.py     | 29 +++++++++++++++----
 2 files changed, 26 insertions(+), 7 deletions(-)

diff --git a/src/transformers/models/gemma/configuration_gemma.py b/src/transformers/models/gemma/configuration_gemma.py
index 6e8b0e66ca34..ef40a1a9a14f 100644
--- a/src/transformers/models/gemma/configuration_gemma.py
+++ b/src/transformers/models/gemma/configuration_gemma.py
@@ -41,7 +41,7 @@ class GemmaConfig(PretrainedConfig):
 
 
     Args:
-        vocab_size (`int`, *optional*, defaults to 32000):
+        vocab_size (`int`, *optional*, defaults to 256000):
             Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
             `inputs_ids` passed when calling [`GemmaModel`]
         hidden_size (`int`, *optional*, defaults to 4096):
@@ -118,7 +118,7 @@ class GemmaConfig(PretrainedConfig):
 
     def __init__(
         self,
-        vocab_size=256128,
+        vocab_size=256000,
         hidden_size=3072,
         intermediate_size=24576,
         num_hidden_layers=28,
diff --git a/src/transformers/models/gemma/convert_gemma_weights_to_hf.py b/src/transformers/models/gemma/convert_gemma_weights_to_hf.py
index f48fe48f9d65..162ad25023d9 100644
--- a/src/transformers/models/gemma/convert_gemma_weights_to_hf.py
+++ b/src/transformers/models/gemma/convert_gemma_weights_to_hf.py
@@ -17,7 +17,7 @@
 
 import torch
 
-from transformers import FlaxGemmaForCausalLM, GemmaConfig, GemmaForCausalLM, GemmaTokenizer
+from transformers import GemmaConfig, GemmaForCausalLM, GemmaTokenizer
 
 
 try:
@@ -61,7 +61,6 @@
 gemma_7b_config = GemmaConfig()
 
 CONFIG_MAPPING = {"2B": gemma_2b_config, "7B": gemma_7b_config}
-
 LAYER_NAME_MAPPING = {"embedder.weight": "model.embed_tokens.weight"}
 
 
@@ -71,7 +70,7 @@ def write_model(save_path, input_base_path, config, safe_serialization=True):
     num_kv_heads = config.num_key_value_heads
     head_dim = config.head_dim
 
-    print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
+    print(f"Fetching all parameters from the checkpoint at '{input_base_path}'")
     model_state_dict = torch.load(input_base_path, map_location="cpu")["model_state_dict"]
     model_state_dict.pop("freqs_cis")
 
@@ -107,7 +106,27 @@ def write_model(save_path, input_base_path, config, safe_serialization=True):
     model.config.torch_dtype = torch.float32
     del model.config._name_or_path
     print("Saving in the Transformers format.")
-    model.save_pretrained(save_path, safe_serialization=safe_serialization)
+    push_to_hub = True
+    if push_to_hub:
+        print(f"pushing the model to {save_path}")
+        response = input("Please enter yes or no: ").lower().strip()
+        result = response == "yes"
+        if result:
+            model.push_to_hub(save_path, safe_serialization=safe_serialization, private=True)
+            print(f"Pushed float32")
+
+            fp16_model = model.to(torch.float16)
+            fp16_model.push_to_hub(save_path, safe_serialization=safe_serialization, revision="float16", private=True)
+            del fp16_model
+            print(f"Pushed float16")
+
+            bf16_model = model.to(torch.bfloat16)
+            bf16_model.push_to_hub(save_path, safe_serialization=safe_serialization, revision="bfloat16", private=True)
+            print(f"Pushed bfloat16")
+    else:
+        model.save_pretrained(save_path, safe_serialization=safe_serialization)
+
+
 
 
 def write_tokenizer(input_tokenizer_path, save_path):
@@ -137,7 +156,7 @@ def main():
     )
     parser.add_argument(
         "--output_dir",
-        default="gemma_7b",
+        default="gg-hf/gemma-7b",
         help="Location to write HF model and tokenizer",
     )
     parser.add_argument(
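
Note (not part of the patch): with the push path above, the script uploads float32 weights to the default branch of the repo named by --output_dir (now defaulting to gg-hf/gemma-7b) and stores half-precision variants on separate "float16" and "bfloat16" revisions. A minimal sketch of how such a repo could be loaded back, assuming the push succeeded and the reader has access to the private repo; the repo id and dtype choice here are illustrative only:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Default --output_dir from the patch; adjust to wherever the script actually pushed.
    repo_id = "gg-hf/gemma-7b"

    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        revision="bfloat16",         # or "float16"; omit to get the float32 default branch
        torch_dtype=torch.bfloat16,  # keep the checkpoint dtype instead of upcasting to float32
    )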