From 51a75681722809b4b393d2c7cda17867448a7e4a Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Fri, 12 Apr 2024 10:45:51 +0200
Subject: [PATCH 1/3] Checkpoint naming scheme

---
 src/transformers/models/llama/convert_llama_weights_to_hf.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/llama/convert_llama_weights_to_hf.py b/src/transformers/models/llama/convert_llama_weights_to_hf.py
index da2cfdb0f91..9f9a375a7fd 100644
--- a/src/transformers/models/llama/convert_llama_weights_to_hf.py
+++ b/src/transformers/models/llama/convert_llama_weights_to_hf.py
@@ -138,7 +138,7 @@ def permute(w, n_heads, dim1=dim, dim2=dim):
     else:
         # Sharded
         loaded = [
-            torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
+            torch.load(os.path.join(input_base_path, f"consolidated.{i:01d}.pth"), map_location="cpu")
             for i in range(num_shards)
         ]
     param_count = 0

From fb48f2fe25207db8ae0fbf44f808a8ea976b011b Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Fri, 12 Apr 2024 10:46:12 +0200
Subject: [PATCH 2/3] Missing n_head to permute

---
 src/transformers/models/llama/convert_llama_weights_to_hf.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/llama/convert_llama_weights_to_hf.py b/src/transformers/models/llama/convert_llama_weights_to_hf.py
index 9f9a375a7fd..c0b102b1d52 100644
--- a/src/transformers/models/llama/convert_llama_weights_to_hf.py
+++ b/src/transformers/models/llama/convert_llama_weights_to_hf.py
@@ -183,7 +183,8 @@ def permute(w, n_heads, dim1=dim, dim2=dim):
                         for i in range(num_shards)
                     ],
                     dim=0,
-                ).reshape(dim, dim)
+                ).reshape(dim, dim),
+                n_heads=n_heads
             )
             state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
                 torch.cat(

From 1f37d2bf1f3af192dc53761b90db53f64bb8586c Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Fri, 12 Apr 2024 10:46:27 +0200
Subject: [PATCH 3/3] Fix embeddings concat for version 3

---
 src/transformers/models/llama/convert_llama_weights_to_hf.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/llama/convert_llama_weights_to_hf.py b/src/transformers/models/llama/convert_llama_weights_to_hf.py
index c0b102b1d52..9a360d2524c 100644
--- a/src/transformers/models/llama/convert_llama_weights_to_hf.py
+++ b/src/transformers/models/llama/convert_llama_weights_to_hf.py
@@ -238,10 +238,11 @@ def permute(w, n_heads, dim1=dim, dim2=dim):
             "lm_head.weight": loaded["output.weight"],
         }
     else:
+        concat_dim = 0 if llama_version == 3 else 1
         state_dict = {
             "model.norm.weight": loaded[0]["norm.weight"],
             "model.embed_tokens.weight": torch.cat(
-                [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
+                [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=concat_dim
             ),
             "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
         }