Commit d0068ef: add mobilevlm

ngxson committed Jan 19, 2025 (1 parent: 6cabdda)

Showing 9 changed files with 216 additions and 67 deletions.
convert_hf_to_gguf.py (66 changes: 47 additions & 19 deletions)
@@ -17,7 +17,7 @@
 from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
 from itertools import chain
 
-from transformers import AutoConfig
+from transformers import AutoConfig, AutoImageProcessor
 import math
 import numpy as np
 import torch
@@ -68,9 +68,10 @@ class Model:
     dir_model_card: Path
 
     # for vision model
+    vision_arch: gguf.MODEL_ARCH | None = None
     preprocessor_config: dict[str, Any] | None = None
     vparams: dict[str, Any] | None = None
-    v_tensor_map: gguf.TensorNameMap
+    v_tensor_map: gguf.TensorNameMap | None = None
     v_tensor_names: set[str] | None
 
     # subclasses should define this!
@@ -102,7 +103,6 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
         self.metadata_override = metadata_override
         self.model_name = model_name
         self.dir_model_card = dir_model  # overridden in convert_lora_to_gguf.py
-        self.preprocessor_config = self.load_preprocessor_config(self.dir_model)
 
         # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
         if self.ftype == gguf.LlamaFileType.GUESSED:
@@ -218,7 +218,7 @@ def match_model_tensor_name(self, name: str, key: gguf.MODEL_TENSOR, bid: int | None, suffix: str = ".weight") -> bool:
 
     def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
         new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes)
-        new_name_vision = self.v_tensor_map.get_name(key=name, try_suffixes=try_suffixes)
+        new_name_vision = self.v_tensor_map.get_name(key=name, try_suffixes=try_suffixes) if self.v_tensor_map is not None else None
         if new_name is not None:
             return new_name
         elif new_name_vision is not None:
@@ -488,14 +488,17 @@ def load_hparams(dir_model: Path):
         return hparams
 
     @staticmethod
-    def load_preprocessor_config(dir_model: Path):
+    def load_preprocessor_config(dir_or_model_id: Path | str):
         # TODO: this varies vastly among models, need to handle more cases in the future
-        file_path = dir_model / "preprocessor_config.json"
-        if os.path.exists(file_path):
-            with open(file_path, "r", encoding="utf-8") as f:
-                return json.load(f)
+        if isinstance(dir_or_model_id, Path):
+            file_path = dir_or_model_id / "preprocessor_config.json"
+            if os.path.exists(file_path):
+                with open(file_path, "r", encoding="utf-8") as f:
+                    return json.load(f)
+            else:
+                raise Exception(f"Preprocessor config not found at {file_path}")
         else:
-            return None
+            return AutoImageProcessor.from_pretrained(dir_or_model_id).to_dict()
 
     @classmethod
     def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
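The reworked helper now resolves a preprocessor config from either a local directory or a bare model id. A minimal usage sketch (the directory path and model id below are illustrative, not taken from this commit):

    from pathlib import Path

    # local checkout: reads preprocessor_config.json from disk
    cfg = Model.load_preprocessor_config(Path("./models/MobileVLM_V2-1.7B"))

    # plain string: treated as a Hugging Face model id, resolved via AutoImageProcessor
    cfg = Model.load_preprocessor_config("openai/clip-vit-large-patch14-336")

    print(cfg["image_mean"], cfg["image_std"])  # normalization stats consumed later by set_vocab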
Expand Down Expand Up @@ -1586,16 +1589,31 @@ def prepare_tensors(self):
raise ValueError(f"Unprocessed norms: {norms}")


@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "LlavaForConditionalGeneration")
@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "LlavaForConditionalGeneration", "MobileLlamaForCausalLM")
class LlamaModel(Model):
model_arch = gguf.MODEL_ARCH.LLAMA

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if "vision_config" in self.hparams:

model_type = self.hparams.get("model_type", None)
self.vision_arch = None

# only tested with https://huggingface.co/llava-hf/llava-1.5-7b-hf
if "vision_config" in self.hparams and model_type == "llava":
self.vparams = self.hparams["vision_config"]
if self.vparams is not None:
self.v_tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAVA_VISION, self.vparams["num_hidden_layers"])
self.preprocessor_config = self.load_preprocessor_config(self.dir_model)
self.vision_arch = gguf.MODEL_ARCH.VISION_LLAVA

# only tested with https://huggingface.co/mtgv/MobileVLM_V2-1.7B
if "mm_vision_tower" in self.hparams and model_type == "mobilevlm":
vision_model_id = self.hparams["mm_vision_tower"]
self.vparams = AutoConfig.from_pretrained(vision_model_id).to_dict()["vision_config"]
self.preprocessor_config = self.load_preprocessor_config(vision_model_id)
self.vision_arch = gguf.MODEL_ARCH.VISION_MOBILEVLM

if self.vparams is not None and self.vision_arch is not None:
self.v_tensor_map = gguf.get_tensor_name_map(self.vision_arch, self.vparams["num_hidden_layers"])

def set_vocab(self):
try:
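The new MobileVLM branch keys off two fields of the checkpoint's config.json. A sketch of the relevant subset (the concrete values are assumptions about what MobileVLM_V2-1.7B ships, not read from this diff):

    # hypothetical excerpt of self.hparams (config.json)
    hparams = {
        "model_type": "mobilevlm",
        "mm_vision_tower": "openai/clip-vit-large-patch14-336",  # assumed CLIP tower id
        "mm_vision_select_layer": -2,                            # assumed; picked up by set_vocab below
    }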
@@ -1631,23 +1649,31 @@ def set_vocab(self):
             self.gguf_writer.add_add_bos_token(False)
 
         # For vision model
-        if self.vparams is not None and self.preprocessor_config is not None:
+        if self.vparams is not None and self.preprocessor_config is not None and self.vision_arch is not None:
             self.gguf_writer.add_vision_type("clip-vit")
             self.gguf_writer.add_vision_image_size(self.vparams["image_size"])
             self.gguf_writer.add_vision_patch_size(self.vparams["patch_size"])
-            self.gguf_writer.add_vision_clip_architecture("llava")
+            self.gguf_writer.add_vision_clip_architecture(gguf.MODEL_ARCH_NAMES[self.vision_arch])
             self.gguf_writer.add_vision_clip_block_count(self.vparams["num_hidden_layers"])
             self.gguf_writer.add_vision_clip_embedding_length(self.vparams["hidden_size"])
             self.gguf_writer.add_vision_clip_feed_forward_length(self.vparams["intermediate_size"])
             self.gguf_writer.add_vision_clip_head_count(self.vparams["num_attention_heads"])
             self.gguf_writer.add_vision_clip_image_mean(self.preprocessor_config["image_mean"])
             self.gguf_writer.add_vision_clip_image_std(self.preprocessor_config["image_std"])
-            self.gguf_writer.add_vision_clip_select_layer(self.hparams["vision_feature_layer"])
             self.gguf_writer.add_vision_clip_patch_merge_type(gguf.CLIPPatchMergeType.FLAT)
             max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1
             self.gguf_writer.add_vision_clip_max_position_embeddings(max_pos_embd)
+            if "vision_feature_layer" in self.hparams:
+                self.gguf_writer.add_vision_clip_select_layer(self.hparams["vision_feature_layer"])
+            elif "mm_vision_select_layer" in self.hparams:
+                self.gguf_writer.add_vision_clip_select_layer(self.hparams["mm_vision_select_layer"])
+            else:
+                raise ValueError("gguf: can not find vision_feature_layer parameter.")
             # TODO: should not hardcode these, but they are currently missing from config.json
-            self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.MLP)
+            if self.vision_arch == gguf.MODEL_ARCH.VISION_LLAVA:
+                self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.MLP)
+            if self.vision_arch == gguf.MODEL_ARCH.VISION_MOBILEVLM:
+                self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.LDPV2)
             self.gguf_writer.add_vision_clip_layer_norm_epsilon(1e-05)
 
     def set_gguf_parameters(self):
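As a sanity check on the position-embedding count written above, assuming the ViT-L/14 tower at 336 px that both tested checkpoints appear to use:

    image_size, patch_size = 336, 14               # assumed ViT-L/14-336 vision tower
    n_patches = (image_size // patch_size) ** 2    # 24 * 24 = 576 patch tokens
    max_pos_embd = n_patches + 1                   # + 1 CLS token -> 577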
@@ -1683,6 +1709,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # For vision model
         if name.startswith("language_model"):
             name = name.replace("language_model.", "")
+        else:
+            name = name.replace("model.vision_tower.", "")
         if "post_layernorm" in name:
             return [] # skip post_layernorm

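The new else branch strips the model.vision_tower. prefix from tower weights so the existing CLIP-style mappings in tensor_mapping.py apply; for example (input name assumed from a MobileVLM checkpoint):

    name = "model.vision_tower.vision_tower.vision_model.embeddings.class_embedding"
    name = name.replace("model.vision_tower.", "")
    # -> "vision_tower.vision_model.embeddings.class_embedding", which maps to v.enc.embd.cls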
@@ -2101,7 +2129,7 @@ def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
         return n_dims > 1
 
 
-@Model.register("MiniCPMForCausalLM")
+@Model.register("MiniCPMForCausalLM", "MiniCPMV")
 class MiniCPMModel(Model):
     model_arch = gguf.MODEL_ARCH.MINICPM
 
gguf-py/gguf/constants.py (34 changes: 30 additions & 4 deletions)
@@ -308,7 +308,8 @@ class MODEL_ARCH(IntEnum):
     CHAMELEON = auto()
     WAVTOKENIZER_DEC = auto()
     # vision models
-    LLAVA_VISION = auto()
+    VISION_LLAVA = auto()
+    VISION_MOBILEVLM = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -439,6 +440,8 @@ class MODEL_TENSOR(IntEnum):
     POSNET_ATTN_OUT = auto()
     # vision
     V_MMPROJ = auto()
+    V_MMPROJ_MLP = auto()
+    V_MMPROJ_PEG = auto()
     V_ENC_EMBD_CLS = auto()
     V_ENC_EMBD_PATCH = auto()
     V_ENC_EMBD_POS = auto()
@@ -512,6 +515,9 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.GRANITE_MOE: "granitemoe",
     MODEL_ARCH.CHAMELEON: "chameleon",
     MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
+    # vision
+    MODEL_ARCH.VISION_LLAVA: "llava",
+    MODEL_ARCH.VISION_MOBILEVLM: "mobilevlm",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -641,6 +647,8 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output",
     # vision
     MODEL_TENSOR.V_MMPROJ: "v.mmproj_{bid}",
+    MODEL_TENSOR.V_MMPROJ_MLP: "v.mmproj.mlp.{bid}",
+    MODEL_TENSOR.V_MMPROJ_PEG: "v.mmproj.peg.{bid}",
     MODEL_TENSOR.V_ENC_EMBD_CLS: "v.enc.embd.cls",
     MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.enc.embd.patch",
     MODEL_TENSOR.V_ENC_EMBD_POS: "v.enc.embd.pos",
@@ -1595,7 +1603,7 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.POSNET_ATTN_V,
         MODEL_TENSOR.POSNET_ATTN_OUT,
     ],
-    MODEL_ARCH.LLAVA_VISION: [
+    MODEL_ARCH.VISION_LLAVA: [
         MODEL_TENSOR.V_MMPROJ,
         MODEL_TENSOR.V_ENC_EMBD_CLS,
         MODEL_TENSOR.V_ENC_EMBD_PATCH,
@@ -1611,6 +1619,23 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.V_PRE_NORM,
         MODEL_TENSOR.V_POST_NORM,
     ],
+    MODEL_ARCH.VISION_MOBILEVLM: [
+        MODEL_TENSOR.V_MMPROJ_MLP,
+        MODEL_TENSOR.V_MMPROJ_PEG,
+        MODEL_TENSOR.V_ENC_EMBD_CLS,
+        MODEL_TENSOR.V_ENC_EMBD_PATCH,
+        MODEL_TENSOR.V_ENC_EMBD_POS,
+        MODEL_TENSOR.V_ENC_ATTN_Q,
+        MODEL_TENSOR.V_ENC_ATTN_K,
+        MODEL_TENSOR.V_ENC_ATTN_V,
+        MODEL_TENSOR.V_ENC_INPUT_NORM,
+        MODEL_TENSOR.V_ENC_OUTPUT,
+        MODEL_TENSOR.V_ENC_OUTPUT_NORM,
+        MODEL_TENSOR.V_ENC_FFN_UP,
+        MODEL_TENSOR.V_ENC_FFN_DOWN,
+        MODEL_TENSOR.V_PRE_NORM,
+        MODEL_TENSOR.V_POST_NORM,
+    ],
     # TODO
 }
 
@@ -1693,11 +1718,12 @@ class PoolingType(IntEnum):
 
 
 class CLIPProjectorType(Enum):
-    MLP = 'mlp'
+    MLP   = 'mlp'
+    LDPV2 = 'ldpv2'
 
 
 class CLIPPatchMergeType(Enum):
-    FLAT = 'flat'
+    FLAT          = 'flat'
+    SPATIAL_UNPAD = 'spatial_unpad'
 
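The new LDPV2 value reflects MobileVLM v2's lightweight downsample projector: a two-layer MLP followed by 2x2 average-pool token downsampling and a depthwise-convolution positional encoding generator (PEG) with a skip connection, which is why the converter exports separate v.mmproj.mlp.* and v.mmproj.peg.* tensors. A rough PyTorch sketch of that structure, based on the MobileVLM v2 paper rather than this repo's code (the 1024/2048 dimensions are assumptions):

    import torch
    import torch.nn as nn

    class LDPv2Sketch(nn.Module):
        def __init__(self, vision_dim: int = 1024, llm_dim: int = 2048):
            super().__init__()
            # two linear layers; their weights would surface as v.mmproj.mlp.{0,2}
            self.mlp = nn.Sequential(
                nn.Linear(vision_dim, llm_dim),
                nn.GELU(),
                nn.Linear(llm_dim, llm_dim),
            )
            self.pool = nn.AvgPool2d(kernel_size=2)  # 2x2 token downsampling
            # depthwise conv positional encoding; would surface as v.mmproj.peg.0
            self.peg = nn.Conv2d(llm_dim, llm_dim, 3, padding=1, groups=llm_dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (batch, n_patches, vision_dim) with n_patches forming a square grid
            x = self.mlp(x)
            b, n, c = x.shape
            hw = int(n ** 0.5)
            x = x.transpose(1, 2).reshape(b, c, hw, hw)
            x = self.pool(x)
            x = x + self.peg(x)                      # PEG with skip connection
            return x.flatten(2).transpose(1, 2)      # (batch, n_patches / 4, llm_dim)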
gguf-py/gguf/gguf_writer.py (2 changes: 1 addition & 1 deletion)
@@ -876,7 +876,7 @@ def add_remove_extra_whitespaces(self, value: bool) -> None:
 
     def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None:
         self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap)
-
+
     def add_vision_type(self, value: str) -> None:
         self.add_string(Keys.Vision.TYPE, value)
 
gguf-py/gguf/tensor_mapping.py (8 changes: 8 additions & 0 deletions)
@@ -794,6 +794,14 @@ class TensorNameMap:
             "multi_modal_projector.linear_{bid}",
         ),
 
+        MODEL_TENSOR.V_MMPROJ_MLP: (
+            "model.mm_projector.mlp.mlp.{bid}",
+        ),
+
+        MODEL_TENSOR.V_MMPROJ_PEG: (
+            "model.mm_projector.peg.peg.{bid}",
+        ),
+
         MODEL_TENSOR.V_ENC_EMBD_CLS: (
             "vision_tower.vision_model.embeddings.class_embedding",
         ),
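With the two projector mappings registered, the converter can translate MobileVLM weight names into their GGUF counterparts. A minimal lookup sketch (assumes the gguf-py from this branch is importable; 24 is just a plausible block count):

    import gguf

    tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.VISION_MOBILEVLM, 24)
    new_name = tmap.get_name("model.mm_projector.mlp.mlp.0.weight", try_suffixes=(".weight", ".bias"))
    print(new_name)  # expected: "v.mmproj.mlp.0.weight"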
src/llama-arch.cpp (31 changes: 30 additions & 1 deletion)
@@ -67,6 +67,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
 
 static const std::map<vision_arch, const char *> VISION_ARCH_NAMES = {
     { VISION_ARCH_LLAVA, "llava" },
+    { VISION_ARCH_MOBILEVLM, "mobilevlm" },
     { VISION_ARCH_UNKNOWN, "(unknown)" },
 };

@@ -1345,7 +1346,27 @@ static const std::map<vision_arch, std::map<vision_tensor, const char *>> VISION_TENSOR_NAMES = {
             { VISION_TENSOR_PRE_NORM, "v.pre_norm" },
             { VISION_TENSOR_POST_NORM, "v.post_norm" },
         }
-    }
+    },
+    {
+        VISION_ARCH_MOBILEVLM,
+        {
+            { VISION_TENSOR_MMPROJ_MLP, "v.mmproj.mlp.%d" },
+            { VISION_TENSOR_MMPROJ_PEG, "v.mmproj.peg.%d" },
+            { VISION_TENSOR_ENC_EMBD_CLS, "v.enc.embd.cls" },
+            { VISION_TENSOR_ENC_EMBD_PATCH, "v.enc.embd.patch" },
+            { VISION_TENSOR_ENC_EMBD_POS, "v.enc.embd.pos" },
+            { VISION_TENSOR_ENC_ATTN_Q, "v.enc.blk.%d.attn_q" },
+            { VISION_TENSOR_ENC_ATTN_K, "v.enc.blk.%d.attn_k" },
+            { VISION_TENSOR_ENC_ATTN_V, "v.enc.blk.%d.attn_v" },
+            { VISION_TENSOR_ENC_INPUT_NORM, "v.enc.blk.%d.input_norm" },
+            { VISION_TENSOR_ENC_OUTPUT, "v.enc.blk.%d.output" },
+            { VISION_TENSOR_ENC_OUTPUT_NORM, "v.enc.blk.%d.output_norm" },
+            { VISION_TENSOR_ENC_FFN_UP, "v.enc.blk.%d.ffn_up" },
+            { VISION_TENSOR_ENC_FFN_DOWN, "v.enc.blk.%d.ffn_down" },
+            { VISION_TENSOR_PRE_NORM, "v.pre_norm" },
+            { VISION_TENSOR_POST_NORM, "v.post_norm" },
+        }
+    },
 };
 
 static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
@@ -1499,6 +1520,10 @@ std::string LLM_KV::operator()(llm_kv kv) const {
 
 template<>
 std::string BASE_TN_IMPL<llm_arch, llm_tensor>::str() const {
+    if (LLM_TENSOR_NAMES.find(arch) == LLM_TENSOR_NAMES.end()) {
+        throw std::runtime_error(format("Cannot find tensor name mapping for arch %d", arch));
+    }
+
     if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) {
         return "__missing__";
     }
@@ -1515,6 +1540,10 @@ std::string BASE_TN_IMPL<llm_arch, llm_tensor>::str() const {
 
 template<>
 std::string BASE_TN_IMPL<vision_arch, vision_tensor>::str() const {
+    if (VISION_TENSOR_NAMES.find(arch) == VISION_TENSOR_NAMES.end()) {
+        throw std::runtime_error(format("Cannot find tensor name mapping for arch %d", arch));
+    }
+
     if (VISION_TENSOR_NAMES.at(arch).find(tensor) == VISION_TENSOR_NAMES.at(arch).end()) {
         return "__missing__";
     }
src/llama-arch.h (3 changes: 3 additions & 0 deletions)
@@ -72,6 +72,7 @@ enum llm_arch {
 enum vision_arch {
     VISION_ARCH_UNKNOWN,
     VISION_ARCH_LLAVA,
+    VISION_ARCH_MOBILEVLM,
 };
 
 enum llm_kv {
@@ -356,6 +357,8 @@ enum llm_tensor {
 
 enum vision_tensor {
     VISION_TENSOR_MMPROJ,
+    VISION_TENSOR_MMPROJ_MLP,
+    VISION_TENSOR_MMPROJ_PEG,
     VISION_TENSOR_ENC_EMBD_CLS,
     VISION_TENSOR_ENC_EMBD_PATCH,
     VISION_TENSOR_ENC_EMBD_POS,