Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support internlm3 #11233

Merged
merged 2 commits into from
Jan 16, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2819,6 +2819,66 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return [(self.map_tensor_name(name), data_torch)]


@Model.register("InternLM3ForCausalLM")
class InternLM3Model(Model):
model_arch = gguf.MODEL_ARCH.LLAMA

def set_vocab(self):
tokens, scores, toktypes = self._create_vocab_sentencepiece()

self.gguf_writer.add_tokenizer_model("llama")
self.gguf_writer.add_tokenizer_pre("default")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))

tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
if tokenizer_config_file.is_file():
with open(tokenizer_config_file, "r", encoding="utf-8") as f:
tokenizer_config_json = json.load(f)
if "add_prefix_space" in tokenizer_config_json:
self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])

if "added_tokens_decoder" in tokenizer_config_json:
for token_id, token_data in tokenizer_config_json["added_tokens_decoder"].items():
if token_data.get("special"):
token_id = int(token_id)
token = token_data["content"]
special_vocab._set_special_token(token, token_id)
# update eos token
if token == '<|im_end|>' and "eos" in special_vocab.special_token_ids:
special_vocab.special_token_ids["eos"] = token_id

special_vocab.add_to_gguf(self.gguf_writer)

def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_vocab_size(hparams["vocab_size"])

if "head_dim" in hparams:
rope_dim = hparams["head_dim"]
else:
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(rope_dim)

if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
if self.hparams["rope_scaling"].get("type") == "linear" or self.hparams["rope_scaling"].get("rope_type") == "linear":
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")
if name.endswith(("q_proj.weight", "q_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
if name.endswith(("k_proj.weight", "k_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
return [(self.map_tensor_name(name), data_torch)]


@Model.register("BertModel", "BertForMaskedLM", "CamembertModel")
class BertModel(Model):
model_arch = gguf.MODEL_ARCH.BERT
Expand Down
Loading