Merge pull request #2 from huggingface/add-llama3-convert-70b
Add llama3 convert 70b
pcuenca authored Apr 15, 2024
2 parents 3e4fac9 + 1f37d2b commit d95e60c
Showing 1 changed file with 5 additions and 3 deletions.
src/transformers/models/llama/convert_llama_weights_to_hf.py (8 changes: 5 additions & 3 deletions)
@@ -143,7 +143,7 @@ def permute(w, n_heads, dim1=dim, dim2=dim):
     else:
         # Sharded
         loaded = [
-            torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
+            torch.load(os.path.join(input_base_path, f"consolidated.{i:01d}.pth"), map_location="cpu")
             for i in range(num_shards)
         ]
     param_count = 0
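
The only change in this hunk is the shard filename format: `{i:02d}` zero-pads the shard index to two digits, while `{i:01d}` does not, so the loader now expects single-digit names for the sharded checkpoints. A quick sanity check of the two format strings (shard count and filenames here are only illustrative):

    # assumption: shards named consolidated.<i>.pth live under input_base_path
    for i in range(3):
        print(f"consolidated.{i:02d}.pth")  # consolidated.00.pth, consolidated.01.pth, consolidated.02.pth
        print(f"consolidated.{i:01d}.pth")  # consolidated.0.pth, consolidated.1.pth, consolidated.2.pth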
@@ -190,7 +190,8 @@ def permute(w, n_heads, dim1=dim, dim2=dim):
                     for i in range(num_shards)
                 ],
                 dim=0,
-            ).reshape(dim, dim)
+            ).reshape(dim, dim),
+            n_heads=n_heads
         )
         state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
             torch.cat(
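
The added lines pass `n_heads` to `permute` by keyword; the signature in the hunk header (`def permute(w, n_heads, dim1=dim, dim2=dim)`) gives it no default, and the k_proj call below uses the key/value head count, which can differ from `n_heads` under grouped-query attention. For reference, a sketch of what the helper does, inferred from the signature shown in the hunk header rather than from this diff, so treat the body as an assumption:

    import torch

    def permute(w, n_heads, dim1, dim2):
        # Move interleaved rotary (cos/sin) pairs into the half-split layout
        # used by the Hugging Face Llama attention implementation.
        return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

    # toy shapes: 2 heads, model dim 8
    print(permute(torch.arange(64.0).reshape(8, 8), n_heads=2, dim1=8, dim2=8).shape)  # torch.Size([8, 8])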
@@ -244,10 +245,11 @@ def permute(w, n_heads, dim1=dim, dim2=dim):
             "lm_head.weight": loaded["output.weight"],
         }
     else:
+        concat_dim = 0 if llama_version == 3 else 1
         state_dict = {
             "model.norm.weight": loaded[0]["norm.weight"],
             "model.embed_tokens.weight": torch.cat(
-                [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
+                [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=concat_dim
             ),
             "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
         }
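
The new `concat_dim` switch implies that sharded Llama 3 checkpoints split `tok_embeddings.weight` along the vocabulary axis (dim 0), whereas the earlier layout split it along the hidden axis (dim 1); that is an inference from this diff, not something the commit states. A toy illustration of the two concatenations (all shapes invented):

    import torch

    # two shards of a (vocab=6, hidden=4) embedding table
    shards_dim1 = [torch.zeros(6, 2), torch.zeros(6, 2)]  # split along hidden dim (older layout)
    shards_dim0 = [torch.zeros(3, 4), torch.zeros(3, 4)]  # split along vocab dim (llama_version == 3)

    print(torch.cat(shards_dim1, dim=1).shape)  # torch.Size([6, 4])
    print(torch.cat(shards_dim0, dim=0).shape)  # torch.Size([6, 4])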
