From 6cabdda0df1a5d89255c3895dc74dfc0eb435048 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sat, 18 Jan 2025 22:56:04 +0100
Subject: [PATCH] add back convert hf to gguf

---
 convert_hf_to_gguf.py          | 70 +++++++++++++++++++++++++--
 examples/server/server.cpp     |  1 +
 gguf-py/gguf/constants.py      | 86 ++++++++++++++++++++++++++++++++++
 gguf-py/gguf/gguf_writer.py    | 53 +++++++++++++++++++++
 gguf-py/gguf/tensor_mapping.py | 58 +++++++++++++++++++++++
 include/llama.h                |  2 +-
 src/llama-vision.h             |  2 +-
 7 files changed, 266 insertions(+), 6 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 95f11204332eb..9e36cad61131c 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -17,6 +17,7 @@
 from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
 from itertools import chain
+from transformers import AutoConfig
 
 import math
 import numpy as np
 import torch
@@ -66,6 +67,12 @@ class Model:
     metadata_override: Path | None
     dir_model_card: Path
 
+    # for vision model
+    preprocessor_config: dict[str, Any] | None = None
+    vparams: dict[str, Any] | None = None
+    v_tensor_map: gguf.TensorNameMap
+    v_tensor_names: set[str] | None
+
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
 
@@ -95,6 +102,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
         self.metadata_override = metadata_override
         self.model_name = model_name
         self.dir_model_card = dir_model  # overridden in convert_lora_to_gguf.py
+        self.preprocessor_config = self.load_preprocessor_config(self.dir_model)
 
         # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
         if self.ftype == gguf.LlamaFileType.GUESSED:
@@ -210,9 +218,13 @@ def match_model_tensor_name(self, name: str, key: gguf.MODEL_TENSOR, bid: int |
 
     def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
         new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes)
-        if new_name is None:
+        new_name_vision = self.v_tensor_map.get_name(key=name, try_suffixes=try_suffixes)
+        if new_name is not None:
+            return new_name
+        elif new_name_vision is not None:
+            return new_name_vision
+        else:
             raise ValueError(f"Can not map tensor {name!r}")
-        return new_name
 
     def set_gguf_parameters(self):
         self.gguf_writer.add_block_count(self.block_count)
@@ -466,7 +478,24 @@ def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]
     @staticmethod
     def load_hparams(dir_model: Path):
         with open(dir_model / "config.json", "r", encoding="utf-8") as f:
-            return json.load(f)
+            hparams = json.load(f)
+        if "text_config" in hparams:
+            text_config = hparams["text_config"]
+            # for example, llava-1.5-7b-hf misses the language model config, need to retrieve it via model ID
+            if "_name_or_path" in text_config:
+                text_config = AutoConfig.from_pretrained(text_config["_name_or_path"]).to_dict()
+            hparams = {**text_config, **hparams}
+        return hparams
+
+    @staticmethod
+    def load_preprocessor_config(dir_model: Path):
+        # TODO: this varies vastly among models, need to handle more cases in the future
+        file_path = dir_model / "preprocessor_config.json"
+        if os.path.exists(file_path):
+            with open(file_path, "r", encoding="utf-8") as f:
+                return json.load(f)
+        else:
+            return None
 
     @classmethod
     def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
@@ -1557,10 +1586,17 @@ def prepare_tensors(self):
                 raise ValueError(f"Unprocessed norms: {norms}")
 
 
-@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
"LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM") +@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "LlavaForConditionalGeneration") class LlamaModel(Model): model_arch = gguf.MODEL_ARCH.LLAMA + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if "vision_config" in self.hparams: + self.vparams = self.hparams["vision_config"] + if self.vparams is not None: + self.v_tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAVA_VISION, self.vparams["num_hidden_layers"]) + def set_vocab(self): try: self._set_vocab_sentencepiece() @@ -1594,6 +1630,26 @@ def set_vocab(self): if self.hparams.get("vocab_size", 32000) == 49152: self.gguf_writer.add_add_bos_token(False) + # For vision model + if self.vparams is not None and self.preprocessor_config is not None: + self.gguf_writer.add_vision_type("clip-vit") + self.gguf_writer.add_vision_image_size(self.vparams["image_size"]) + self.gguf_writer.add_vision_patch_size(self.vparams["patch_size"]) + self.gguf_writer.add_vision_clip_architecture("llava") + self.gguf_writer.add_vision_clip_block_count(self.vparams["num_hidden_layers"]) + self.gguf_writer.add_vision_clip_embedding_length(self.vparams["hidden_size"]) + self.gguf_writer.add_vision_clip_feed_forward_length(self.vparams["intermediate_size"]) + self.gguf_writer.add_vision_clip_head_count(self.vparams["num_attention_heads"]) + self.gguf_writer.add_vision_clip_image_mean(self.preprocessor_config["image_mean"]) + self.gguf_writer.add_vision_clip_image_std(self.preprocessor_config["image_std"]) + self.gguf_writer.add_vision_clip_select_layer(self.hparams["vision_feature_layer"]) + self.gguf_writer.add_vision_clip_patch_merge_type(gguf.CLIPPatchMergeType.FLAT) + max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1 + self.gguf_writer.add_vision_clip_max_position_embeddings(max_pos_embd) + # TODO: should not hardcode these, but they are currently missing from config.json + self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.MLP) + self.gguf_writer.add_vision_clip_layer_norm_epsilon(1e-05) + def set_gguf_parameters(self): super().set_gguf_parameters() hparams = self.hparams @@ -1624,6 +1680,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") + # For vision model + if name.startswith("language_model"): + name = name.replace("language_model.", "") + if "post_layernorm" in name: + return [] # skip post_layernorm + if name.endswith(("q_proj.weight", "q_proj.bias")): data_torch = LlamaModel.permute(data_torch, n_head, n_head) if name.endswith(("k_proj.weight", "k_proj.bias")): diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 64c0c4ef68f13..83aa946e2a64c 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2949,6 +2949,7 @@ struct server_context { batch.n_seq_id + i, batch.seq_id + i, batch.logits + i, + nullptr, }; const int ret = llama_decode(ctx, batch_view); diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 8fe84df21ea20..411c89e7f5373 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -202,6 +202,9 @@ class Tokenizer: FIM_PAD_ID = "tokenizer.ggml.fim_pad_token_id" FIM_REP_ID = "tokenizer.ggml.fim_rep_token_id" FIM_SEP_ID = "tokenizer.ggml.fim_sep_token_id" + # Vision models + IMAGE_START_ID = "tokenizer.ggml.image_start_token_id" + 
IMAGE_END_ID = "tokenizer.ggml.image_end_token_id" # deprecated: PREFIX_ID = "tokenizer.ggml.prefix_token_id" SUFFIX_ID = "tokenizer.ggml.suffix_token_id" @@ -211,6 +214,31 @@ class Adapter: TYPE = "adapter.type" LORA_ALPHA = "adapter.lora.alpha" + class Vision: + # only support vision.type = "clip-vit" for now + TYPE = "vision.type" + IMAGE_SIZE = "vision.image_size" + PATCH_SIZE = "vision.patch_size" + IMAGE_MEAN = "vision.image_mean" + IMAGE_STD = "vision.image_std" + + class Clip: + ARCHITECTURE = "vision.clip.architecture" + CONTEXT_LENGTH = "vision.clip.context_length" + EMBEDDING_LENGTH = "vision.clip.embedding_length" + BLOCK_COUNT = "vision.clip.block_count" + FEED_FORWARD_LENGTH = "vision.clip.feed_forward_length" + PROJECTION_TYPE = "vision.clip.projection_type" + PROJECTION_DIM = "vision.clip.projection_dim" + USE_GELU = "vision.clip.use_gelu" + MAX_POS_EMBEDDING = "vision.clip.max_position_embeddings" + MAX_SLICES = "vision.clip.max_slices" + PROJECTOR_TYPE = "vision.clip.projector_type" + SELECT_LAYER = "vision.clip.select_layer" + PATCH_MERGE_TYPE = "vision.clip.patch_merge_type" + HEAD_COUNT = "vision.clip.attention.head_count" + LAYERNORM_EPS = "vision.clip.attention.layer_norm_epsilon" + # # recommended mapping of model tensor names for storage in gguf # @@ -279,6 +307,8 @@ class MODEL_ARCH(IntEnum): GRANITE_MOE = auto() CHAMELEON = auto() WAVTOKENIZER_DEC = auto() + # vision models + LLAVA_VISION = auto() class MODEL_TENSOR(IntEnum): @@ -390,6 +420,7 @@ class MODEL_TENSOR(IntEnum): ENC_OUTPUT_NORM = auto() CLS = auto() # classifier CLS_OUT = auto() # classifier output projection + # wavtokenizer CONV1D = auto() CONVNEXT_DW = auto() CONVNEXT_NORM = auto() @@ -406,6 +437,21 @@ class MODEL_TENSOR(IntEnum): POSNET_ATTN_K = auto() POSNET_ATTN_V = auto() POSNET_ATTN_OUT = auto() + # vision + V_MMPROJ = auto() + V_ENC_EMBD_CLS = auto() + V_ENC_EMBD_PATCH = auto() + V_ENC_EMBD_POS = auto() + V_ENC_ATTN_Q = auto() + V_ENC_ATTN_K = auto() + V_ENC_ATTN_V = auto() + V_ENC_INPUT_NORM = auto() + V_ENC_OUTPUT = auto() + V_ENC_OUTPUT_NORM = auto() + V_ENC_FFN_UP = auto() + V_ENC_FFN_DOWN = auto() + V_PRE_NORM = auto() + V_POST_NORM = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -593,6 +639,21 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.POSNET_ATTN_K: "posnet.{bid}.attn_k", MODEL_TENSOR.POSNET_ATTN_V: "posnet.{bid}.attn_v", MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output", + # vision + MODEL_TENSOR.V_MMPROJ: "v.mmproj_{bid}", + MODEL_TENSOR.V_ENC_EMBD_CLS: "v.enc.embd.cls", + MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.enc.embd.patch", + MODEL_TENSOR.V_ENC_EMBD_POS: "v.enc.embd.pos", + MODEL_TENSOR.V_ENC_ATTN_Q: "v.enc.blk.{bid}.attn_q", + MODEL_TENSOR.V_ENC_ATTN_K: "v.enc.blk.{bid}.attn_k", + MODEL_TENSOR.V_ENC_ATTN_V: "v.enc.blk.{bid}.attn_v", + MODEL_TENSOR.V_ENC_INPUT_NORM: "v.enc.blk.{bid}.input_norm", + MODEL_TENSOR.V_ENC_OUTPUT: "v.enc.blk.{bid}.output", + MODEL_TENSOR.V_ENC_OUTPUT_NORM: "v.enc.blk.{bid}.output_norm", + MODEL_TENSOR.V_ENC_FFN_UP: "v.enc.blk.{bid}.ffn_up", + MODEL_TENSOR.V_ENC_FFN_DOWN: "v.enc.blk.{bid}.ffn_down", + MODEL_TENSOR.V_PRE_NORM: "v.pre_norm", + MODEL_TENSOR.V_POST_NORM: "v.post_norm", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -1534,6 +1595,22 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.POSNET_ATTN_V, MODEL_TENSOR.POSNET_ATTN_OUT, ], + MODEL_ARCH.LLAVA_VISION: [ + MODEL_TENSOR.V_MMPROJ, + MODEL_TENSOR.V_ENC_EMBD_CLS, + MODEL_TENSOR.V_ENC_EMBD_PATCH, + MODEL_TENSOR.V_ENC_EMBD_POS, + MODEL_TENSOR.V_ENC_ATTN_Q, + 
+        MODEL_TENSOR.V_ENC_ATTN_K,
+        MODEL_TENSOR.V_ENC_ATTN_V,
+        MODEL_TENSOR.V_ENC_INPUT_NORM,
+        MODEL_TENSOR.V_ENC_OUTPUT,
+        MODEL_TENSOR.V_ENC_OUTPUT_NORM,
+        MODEL_TENSOR.V_ENC_FFN_UP,
+        MODEL_TENSOR.V_ENC_FFN_DOWN,
+        MODEL_TENSOR.V_PRE_NORM,
+        MODEL_TENSOR.V_POST_NORM,
+    ],
     # TODO
 }
 
@@ -1615,6 +1692,15 @@ class PoolingType(IntEnum):
     CLS  = 2
 
 
+class CLIPProjectorType(Enum):
+    MLP = 'mlp'
+
+
+class CLIPPatchMergeType(Enum):
+    FLAT = 'flat'
+    SPATIAL_UNPAD = 'spatial_unpad'
+
+
 class GGMLQuantizationType(IntEnum):
     F32 = 0
     F16 = 1
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 080d2b9dce5cb..5438acd06132b 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -27,6 +27,8 @@
     PoolingType,
     TokenType,
     ExpertGatingFuncType,
+    CLIPPatchMergeType,
+    CLIPProjectorType,
 )
 
 from .quants import quant_shape_from_byte_shape
@@ -874,6 +876,57 @@ def add_remove_extra_whitespaces(self, value: bool) -> None:
 
     def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None:
         self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap)
+
+    def add_vision_type(self, value: str) -> None:
+        self.add_string(Keys.Vision.TYPE, value)
+
+    def add_vision_image_size(self, value: int) -> None:
+        self.add_uint32(Keys.Vision.IMAGE_SIZE, value)
+
+    def add_vision_patch_size(self, value: int) -> None:
+        self.add_uint32(Keys.Vision.PATCH_SIZE, value)
+
+    def add_vision_clip_architecture(self, value: str) -> None:
+        self.add_string(Keys.Vision.Clip.ARCHITECTURE, value)
+
+    def add_vision_clip_context_length(self, value: int) -> None:
+        self.add_uint32(Keys.Vision.Clip.CONTEXT_LENGTH, value)
+
+    def add_vision_clip_embedding_length(self, value: int) -> None:
+        self.add_uint32(Keys.Vision.Clip.EMBEDDING_LENGTH, value)
+
+    def add_vision_clip_block_count(self, value: int) -> None:
+        self.add_uint32(Keys.Vision.Clip.BLOCK_COUNT, value)
+
+    def add_vision_clip_feed_forward_length(self, value: int) -> None:
+        self.add_uint32(Keys.Vision.Clip.FEED_FORWARD_LENGTH, value)
+
+    def add_vision_clip_head_count(self, value: int) -> None:
+        self.add_uint32(Keys.Vision.Clip.HEAD_COUNT, value)
+
+    def add_vision_clip_max_position_embeddings(self, value: int) -> None:
+        self.add_uint32(Keys.Vision.Clip.MAX_POS_EMBEDDING, value)
+
+    def add_vision_clip_projector_type(self, value: CLIPProjectorType) -> None:
+        self.add_string(Keys.Vision.Clip.PROJECTOR_TYPE, value.value)
+
+    def add_vision_clip_max_slices(self, value: int) -> None:
+        self.add_uint32(Keys.Vision.Clip.MAX_SLICES, value)
+
+    def add_vision_clip_select_layer(self, value: int) -> None:
+        self.add_int32(Keys.Vision.Clip.SELECT_LAYER, value)
+
+    def add_vision_clip_patch_merge_type(self, value: CLIPPatchMergeType) -> None:
+        self.add_string(Keys.Vision.Clip.PATCH_MERGE_TYPE, value.value)
+
+    def add_vision_clip_layer_norm_epsilon(self, value: float) -> None:
+        self.add_float32(Keys.Vision.Clip.LAYERNORM_EPS, value)
+
+    def add_vision_clip_image_mean(self, value: Sequence[float]) -> None:
+        self.add_array(Keys.Vision.IMAGE_MEAN, value)
+
+    def add_vision_clip_image_std(self, value: Sequence[float]) -> None:
+        self.add_array(Keys.Vision.IMAGE_STD, value)
 
     def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
         if not isinstance(value, str):
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 617791e240b60..813f8f7e052ce 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -787,6 +787,64 @@ class TensorNameMap:
         MODEL_TENSOR.POSNET_ATTN_OUT: (
"backbone.posnet.{bid}.proj_out", # wavtokenizer ), + + ############################################################################# + + MODEL_TENSOR.V_MMPROJ: ( + "multi_modal_projector.linear_{bid}", + ), + + MODEL_TENSOR.V_ENC_EMBD_CLS: ( + "vision_tower.vision_model.embeddings.class_embedding", + ), + + MODEL_TENSOR.V_ENC_EMBD_PATCH: ( + "vision_tower.vision_model.embeddings.patch_embedding", + ), + + MODEL_TENSOR.V_ENC_EMBD_POS: ( + "vision_tower.vision_model.embeddings.position_embedding", + ), + + MODEL_TENSOR.V_ENC_ATTN_Q: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj", + ), + + MODEL_TENSOR.V_ENC_ATTN_K: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj", + ), + + MODEL_TENSOR.V_ENC_ATTN_V: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj", + ), + + MODEL_TENSOR.V_ENC_INPUT_NORM: ( + "vision_tower.vision_model.encoder.layers.{bid}.layer_norm1", + ), + + MODEL_TENSOR.V_ENC_OUTPUT: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj", + ), + + MODEL_TENSOR.V_ENC_OUTPUT_NORM: ( + "vision_tower.vision_model.encoder.layers.{bid}.layer_norm2", + ), + + MODEL_TENSOR.V_ENC_FFN_UP: ( + "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1", + ), + + MODEL_TENSOR.V_ENC_FFN_DOWN: ( + "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2", + ), + + MODEL_TENSOR.V_PRE_NORM: ( + "vision_tower.vision_model.pre_layrnorm", + ), + + MODEL_TENSOR.V_POST_NORM: ( + "vision_tower.vision_model.post_layernorm", + ), } # architecture-specific block mappings diff --git a/include/llama.h b/include/llama.h index 5013e96e78825..bd8e696585693 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1292,7 +1292,7 @@ extern "C" { // Encode patches into embeddings LLAMA_API int32_t llama_vision_encode(struct llama_context * ctx, struct llama_vision_patches * p); - LLAMA_API struct ggml_tensor * llama_vision_get_output_tensor(llama_context * ctx); + LLAMA_API struct ggml_tensor * llama_vision_get_output_tensor(struct llama_context * ctx); // // Model split diff --git a/src/llama-vision.h b/src/llama-vision.h index 56c6b49c96ed9..ced58dd0b88ca 100644 --- a/src/llama-vision.h +++ b/src/llama-vision.h @@ -40,7 +40,7 @@ struct clip_hparams { std::array image_mean; std::array image_std; - std::array image_grid_pinpoints; + std::array image_grid_pinpoints; // TODO: should this be array of (x, y) pairs? int32_t image_crop_resolution; };