Added support for llama v2 and codellama in weight conversion for issue #28241 (#28767)

Closed · wants to merge 3 commits
src/transformers/models/llama/convert_llama_weights_to_hf.py (14 changes: 11 additions & 3 deletions)
@@ -80,7 +80,7 @@ def write_json(text, path):
         json.dump(text, f)


-def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True):
+def write_model(model_path, input_base_path, model_size, llama_version, tokenizer_path=None, safe_serialization=True):
     # for backward compatibility, before you needed the repo to be called `my_repo/model_size`
     if not os.path.isfile(os.path.join(input_base_path, "params.json")):
         input_base_path = os.path.join(input_base_path, model_size)
@@ -99,9 +99,11 @@ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True):
     dims_per_head = dim // n_heads
     base = params.get("rope_theta", 10000.0)
     inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
-    if base > 10000.0:
+    if base > 10000.0 or llama_version == "code":
         max_position_embeddings = 16384
-    else:
+    elif llama_version == "v2":
+        max_position_embeddings = 4096
+    else:  # defaults to v1 LLaMa
         max_position_embeddings = 2048

     tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
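
For readers skimming the diff, here is a minimal standalone sketch of the context-length selection this hunk introduces. The function name and literal inputs are hypothetical, not part of the PR; in the actual script, `rope_theta` is read from the checkpoint's `params.json`:

```python
# Hypothetical standalone reproduction of the if/elif/else above.
def max_context_length(rope_theta: float, llama_version: str) -> int:
    if rope_theta > 10000.0 or llama_version == "code":
        return 16384  # CodeLlama
    elif llama_version == "v2":
        return 4096   # Llama 2
    else:
        return 2048   # original LLaMA (v1)

assert max_context_length(10000.0, "v1") == 2048
assert max_context_length(10000.0, "v2") == 4096
# CodeLlama checkpoints ship a larger rope_theta, so either signal selects 16384:
assert max_context_length(1000000.0, "v1") == 16384
assert max_context_length(10000.0, "code") == 16384
```

Keeping the `rope_theta` check alongside the explicit version flag preserves the old heuristic, so CodeLlama checkpoints are still detected even when `--llama_version` is omitted.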
@@ -296,6 +298,11 @@ def main():
         choices=["7B", "7Bf", "13B", "13Bf", "30B", "34B", "65B", "70B", "70Bf", "tokenizer_only"],
         help="'f' models correspond to the finetuned versions, and are specific to the Llama2 official release. For more details on Llama2, checkout the original repo: https://huggingface.co/meta-llama",
     )
+    parser.add_argument(
+        "--llama_version",
+        choices=["v1", "v2", "code"],
+        help="Specifies which LLaMa version to use, as each version has a different maximum context length. LLaMa v1 has a maximum context length of 2048, LLaMa v2 has a maximum context length of 4096, and CodeLLaMa has a maximum context length of 16384. Defaults to v1.",
+    )
     parser.add_argument(
         "--output_dir",
         help="Location to write HF model and tokenizer",
@@ -308,6 +315,7 @@ def main():
         model_path=args.output_dir,
         input_base_path=args.input_dir,
         model_size=args.model_size,
+        llama_version=args.llama_version,
         safe_serialization=args.safe_serialization,
         tokenizer_path=spm_path,
     )
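
Finally, a hedged sketch of how the converter could be driven with the new parameter once this PR is applied. The paths are placeholders, and importing the conversion module directly is an assumption made for illustration; the script is normally run from the command line:

```python
# Illustrative only: placeholder paths; assumes the patched script is importable.
from transformers.models.llama.convert_llama_weights_to_hf import write_model

write_model(
    model_path="./llama-2-7b-hf",                      # placeholder output directory
    input_base_path="./llama-2-7b",                    # placeholder Meta checkpoint directory
    model_size="7B",
    llama_version="v2",                                # new parameter added by this PR
    tokenizer_path="./llama-2-7b/tokenizer.model",     # placeholder SentencePiece model
    safe_serialization=True,
)
```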