Skip to content

Commit

Permalink
model: dbrx convert to gguf
Browse files Browse the repository at this point in the history
  • Loading branch information
phymbert committed Apr 6, 2024
1 parent a307375 commit 1d8de31
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 40 deletions.
54 changes: 54 additions & 0 deletions convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1424,6 +1424,60 @@ def write_tensors(self):
self.gguf_writer.add_tensor(new_name, data)


@Model.register("DbrxForCausalLM")
class Qwen2MoeModel(Model):
model_arch = gguf.MODEL_ARCH.DBRX

def set_gguf_parameters(self):
ffn_config = self.hparams["ffn_config"]
attn_config = self.hparams["attn_config"]
self.gguf_writer.add_name(self.hparams["model_type"])
self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
self.gguf_writer.add_block_count(self.hparams["n_layers"])
self.gguf_writer.add_head_count(self.hparams["n_heads"])
self.gguf_writer.add_head_count_kv(attn_config["kv_n_heads"])
self.gguf_writer.add_rope_freq_base(attn_config["rope_theta"])
self.gguf_writer.add_clip_kqv(attn_config["clip_qkv"])
self.gguf_writer.add_file_type(self.ftype)

self.gguf_writer.add_expert_count(ffn_config["moe_num_experts"])
self.gguf_writer.add_expert_used_count(ffn_config["moe_top_k"])

def _set_vocab_gpt2(self):
dir_model = self.dir_model
hparams = self.hparams
tokens: list[str] = []
toktypes: list[int] = []

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(dir_model)
vocab_size = tokenizer.vocab_size

reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.get_vocab().items()}
added_vocab = tokenizer.get_added_vocab()

for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.USER_DEFINED)
elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i])
if tokenizer.added_tokens_decoder[i].special:
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.USER_DEFINED)
else:
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.NORMAL)

self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
special_vocab.add_to_gguf(self.gguf_writer)

@Model.register("MiniCPMForCausalLM")
class MiniCPMModel(Model):
model_arch = gguf.MODEL_ARCH.MINICPM
Expand Down
17 changes: 17 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class Attention:
LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
CAUSAL = "{arch}.attention.causal"
CLIP_KQV = "{arch}.attention.clip_kqv"

class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
Expand Down Expand Up @@ -125,6 +126,7 @@ class MODEL_ARCH(IntEnum):
MAMBA = auto()
XVERSE = auto()
COMMAND_R = auto()
DBRX = auto()


class MODEL_TENSOR(IntEnum):
Expand Down Expand Up @@ -194,6 +196,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r",
MODEL_ARCH.DBRX: "dbrx",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
Expand Down Expand Up @@ -639,6 +642,20 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.DBRX: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_NORM_2,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.LAYER_OUT_NORM,
],
# TODO
}

Expand Down
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,9 @@ def add_layer_norm_rms_eps(self, value: float) -> None:
def add_causal_attention(self, value: bool) -> None:
self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)

def add_clip_kqv(self, value: int) -> None:
self.add_uint32(Keys.Attention.CLIP_KQV.format(arch=self.arch), value)

def add_pooling_type(self, value: PoolingType) -> None:
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)

Expand Down
88 changes: 48 additions & 40 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,35 +79,38 @@ class TensorNameMap:
block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
# Attention norm
MODEL_TENSOR.ATTN_NORM: (
"gpt_neox.layers.{bid}.input_layernorm", # gptneox
"transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen
"transformer.blocks.{bid}.norm_1", # mpt
"transformer.h.{bid}.input_layernorm", # falcon7b
"h.{bid}.input_layernorm", # bloom
"transformer.h.{bid}.ln_mlp", # falcon40b
"model.layers.{bid}.input_layernorm", # llama-hf
"layers.{bid}.attention_norm", # llama-pth
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
"model.layers.{bid}.ln1", # yi
"h.{bid}.ln_1", # gpt2
"transformer.h.{bid}.ln", # phi2
"model.layers.layers.{bid}.norm", # plamo
"model.layers.{bid}.attention_norm", # internlm2
"model.layers.{bid}.norm", # mamba-qbert
"backbone.layers.{bid}.norm", # mamba
"transformer.decoder_layer.{bid}.rms_norm", # Grok
"gpt_neox.layers.{bid}.input_layernorm", # gptneox
"transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen
"transformer.blocks.{bid}.norm_1", # mpt
"transformer.h.{bid}.input_layernorm", # falcon7b
"h.{bid}.input_layernorm", # bloom
"transformer.h.{bid}.ln_mlp", # falcon40b
"model.layers.{bid}.input_layernorm", # llama-hf
"layers.{bid}.attention_norm", # llama-pth
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
"model.layers.{bid}.ln1", # yi
"h.{bid}.ln_1", # gpt2
"transformer.h.{bid}.ln", # phi2
"model.layers.layers.{bid}.norm", # plamo
"model.layers.{bid}.attention_norm", # internlm2
"model.layers.{bid}.norm", # mamba-qbert
"backbone.layers.{bid}.norm", # mamba
"transformer.decoder_layer.{bid}.rms_norm", # Grok
"transformer.blocks.{bid}.norm_attn_norm.norm_1.weight", # DBRX
),

# Attention norm 2
MODEL_TENSOR.ATTN_NORM_2: (
"transformer.h.{bid}.ln_attn", # falcon40b
"transformer.h.{bid}.ln_attn", # falcon40b
"transformer.blocks.{bid}.norm_attn_norm.norm_2.weight", # DBRX
),

# Attention query-key-value
MODEL_TENSOR.ATTN_QKV: (
"gpt_neox.layers.{bid}.attention.query_key_value", # gptneox
"transformer.h.{bid}.attn.c_attn", # gpt2 qwen
"transformer.blocks.{bid}.attn.Wqkv", # mpt
"transformer.blocks.{bid}.norm_attn_norm.attn.Wqkv.weight", # DBRX
"transformer.h.{bid}.self_attention.query_key_value", # falcon
"h.{bid}.self_attention.query_key_value", # bloom
"language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
Expand Down Expand Up @@ -152,23 +155,24 @@ class TensorNameMap:

# Attention output
MODEL_TENSOR.ATTN_OUT: (
"gpt_neox.layers.{bid}.attention.dense", # gptneox
"transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen
"transformer.blocks.{bid}.attn.out_proj", # mpt
"transformer.h.{bid}.self_attention.dense", # falcon
"h.{bid}.self_attention.dense", # bloom
"model.layers.{bid}.self_attn.o_proj", # llama-hf
"layers.{bid}.attention.wo", # llama-pth
"encoder.layer.{bid}.attention.output.dense", # bert
"transformer.h.{bid}.attn.out_proj", # gpt-j
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
"model.layers.{bid}.self_attn.dense", # persimmon
"h.{bid}.attn.c_proj", # gpt2
"transformer.h.{bid}.mixer.out_proj", # phi2
"model.layers.layers.{bid}.self_attn.o_proj", # plamo
"model.layers.{bid}.attention.wo", # internlm2
"encoder.layers.{bid}.attn.out_proj", # nomic-bert
"transformer.decoder_layer.{bid}.multi_head_attention.linear"# Grok
"gpt_neox.layers.{bid}.attention.dense", # gptneox
"transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen
"transformer.blocks.{bid}.attn.out_proj", # mpt
"transformer.h.{bid}.self_attention.dense", # falcon
"h.{bid}.self_attention.dense", # bloom
"model.layers.{bid}.self_attn.o_proj", # llama-hf
"layers.{bid}.attention.wo", # llama-pth
"encoder.layer.{bid}.attention.output.dense", # bert
"transformer.h.{bid}.attn.out_proj", # gpt-j
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
"model.layers.{bid}.self_attn.dense", # persimmon
"h.{bid}.attn.c_proj", # gpt2
"transformer.h.{bid}.mixer.out_proj", # phi2
"model.layers.layers.{bid}.self_attn.o_proj", # plamo
"model.layers.{bid}.attention.wo", # internlm2
"encoder.layers.{bid}.attn.out_proj", # nomic-bert
"transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok
"transformer.blocks.{bid}.norm_attn_norm.attn.out_proj.weight", # DBRX
),

# Attention output norm
Expand Down Expand Up @@ -202,9 +206,10 @@ class TensorNameMap:
),

MODEL_TENSOR.FFN_GATE_INP: (
"layers.{bid}.feed_forward.gate", # mixtral
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
"transformer.decoder_layer.{bid}.router" # Grok
"layers.{bid}.feed_forward.gate", # mixtral
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
"transformer.decoder_layer.{bid}.router", # Grok
"transformer.blocks.{bid}.ffn.router.layer.weight", # DBRX
),

# Feed-forward up
Expand Down Expand Up @@ -233,6 +238,7 @@ class TensorNameMap:
MODEL_TENSOR.FFN_UP_EXP: (
"layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.w2", # DBRX
),

# AWQ-activation gate
Expand All @@ -251,8 +257,9 @@ class TensorNameMap:
),

MODEL_TENSOR.FFN_GATE_EXP: (
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear" # Grok (merged)
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # DBRX
),

# Feed-forward down
Expand Down Expand Up @@ -280,6 +287,7 @@ class TensorNameMap:
MODEL_TENSOR.FFN_DOWN_EXP: (
"layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # DBRX
),

MODEL_TENSOR.ATTN_Q_NORM: (
Expand Down
1 change: 1 addition & 0 deletions requirements/requirements-convert-hf-to-gguf.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
-r ./requirements-convert.txt
torch~=2.1.1
einops~=0.7.0
tiktoken~=0.6.0

0 comments on commit 1d8de31

Please sign in to comment.