From 56d2f72324845cf36b95d4c9bb5cf779f3d03134 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Sat, 8 Feb 2025 22:12:53 +0100
Subject: [PATCH] [RFC] [Mistral] FP8 format (#10130)

Signed-off-by: mgoin
Co-authored-by: mgoin
Signed-off-by: SzymonOzog
---
 vllm/model_executor/models/llama.py           | 20 ++++++++--
 vllm/model_executor/models/pixtral.py         |  7 +++-
 vllm/transformers_utils/config.py             | 37 ++++++++++++++++---
 vllm/transformers_utils/tokenizers/mistral.py |  3 +-
 4 files changed, 55 insertions(+), 12 deletions(-)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 866c69234753c..2ff52dd789125 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -467,6 +467,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     mistral_mapping = {
         "layers": "model.layers",
         "attention": "self_attn",
+        "qscale_act": "input_scale",
+        "qscale_weight": "weight_scale",
+        "kv_fake_quantizer.qscale_act": "kv_scale",
         "wq": "q_proj",
         "wk": "k_proj",
         "wv": "v_proj",
@@ -590,15 +593,24 @@ def permute(w: torch.Tensor, n_heads: int):
         modules = name.split(".")
 
         # rotary embeds should be sliced
-        if "wk" in modules:
+        if "wk" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
                                     self.config.num_key_value_heads)
-        elif "wq" in modules:
+        elif "wq" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
                                     self.config.num_attention_heads)
 
-        for item in modules:
-            if item in mapping and mapping[item] not in name:
+        num_modules = len(modules)
+        for i in range(num_modules):
+            item = modules[i]
+            next_item = modules[i + 1] if i < num_modules - 1 else None
+
+            combined_item = (f"{item}.{next_item}"
+                             if next_item is not None else None)
+
+            if combined_item in mapping:
+                name = name.replace(combined_item, mapping[combined_item])
+            elif item in mapping and mapping[item] not in name:
                 name = name.replace(item, mapping[item])
 
         return name, loaded_weight
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index 003e9c84c1c0a..e78e8d62cc47c 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -54,8 +54,11 @@ def get_max_pixtral_image_tokens(ctx: InputContext):
         tokenizer_mode=ctx.model_config.tokenizer_mode)
     mm_encoder = tokenizer.instruct.mm_encoder
 
-    max_image_size = mm_encoder.mm_config.max_image_size
-    image_patch_size = mm_encoder.mm_config.image_patch_size
+    image_config = mm_encoder.mm_config if hasattr(
+        mm_encoder, "mm_config") else mm_encoder.image_config
+
+    max_image_size = image_config.max_image_size
+    image_patch_size = image_config.image_patch_size
 
     return ((max_image_size // image_patch_size)**2)
 
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index fb5cc3ec0722b..42b45e10e3f25 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -4,7 +4,7 @@
 import json
 import os
 from pathlib import Path
-from typing import Any, Dict, Optional, Type, Union
+from typing import Any, Dict, Literal, Optional, Type, Union
 
 import huggingface_hub
 from huggingface_hub import (file_exists, hf_hub_download, list_repo_files,
@@ -554,7 +554,8 @@ def recurse_elems(elem: Any):
             for key, value in elem.items():
                 key = config_mapping.get(key, key)
                 config_dict[key] = recurse_elems(value)
-            return PretrainedConfig(**config_dict)
+
+            return config_dict
         else:
             return elem
 
@@ -566,12 +567,30 @@ def recurse_elems(elem: Any):
     config_dict["max_position_embeddings"] = config_dict.get(
         "max_position_embeddings", 128_000)
 
+    if config_dict.get("quantization") is not None:
+        quantization = config_dict.get("quantization", {})
+        if quantization.get("qformat_weight") == "fp8_e4m3":
+            # This maps to the FP8 static per-tensor quantization scheme
+            quantization_config = {
+                "quant_method": "fp8",
+                "activation_scheme": "static"
+            }
+        else:
+            raise ValueError(
+                f"Found unknown quantization='{quantization}' in config")
+
+        config_dict["quantization_config"] = quantization_config
+
+    config_type: Literal["text",
+                         "multimodal"] = "multimodal" if config_dict.get(
+                             "vision_encoder") is not None else "text"
+
     if config_dict.get("moe") is not None:
         config_dict["architectures"] = ["MixtralForCausalLM"]
     else:
         config_dict["architectures"] = ["MistralForCausalLM"]
 
-    if config_dict.get("vision_encoder") is not None:
+    if config_type == "multimodal":
         multimodal_config = config_dict.pop("vision_encoder")
 
         config_dict = {
@@ -583,8 +602,16 @@ def recurse_elems(elem: Any):
 
     config_dict.update(kwargs)
 
-    config = recurse_elems(config_dict)
-    return config
+    config_dict = recurse_elems(config_dict)
+
+    # transform to HF config format
+    if config_type == "multimodal":
+        config_dict["text_config"] = PretrainedConfig(
+            **config_dict["text_config"])
+        config_dict["vision_config"] = PretrainedConfig(
+            **config_dict["vision_config"])
+
+    return PretrainedConfig(**config_dict)
 
 
 def get_hf_image_processor_config(
diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py
index 7a1dba424464e..8d96fcd278e67 100644
--- a/vllm/transformers_utils/tokenizers/mistral.py
+++ b/vllm/transformers_utils/tokenizers/mistral.py
@@ -88,7 +88,8 @@ def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
 
 def find_tokenizer_file(files: List[str]):
-    file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$")
+    file_pattern = re.compile(
+        r"^tokenizer\.model\.v.*$|^tekken\.json$|^tokenizer\.mm\.model\.v.*$")
 
     matched_files = [file for file in files if file_pattern.match(file)]
 
     if len(matched_files) > 1:
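
# Standalone sketch (not part of the patch): how the combined-key lookup added
# to the weight-name remapping loop in llama.py above renames Mistral-format
# FP8 checkpoint entries. The mapping is trimmed to the relevant keys and the
# weight names below are illustrative, not taken from a real checkpoint.

mistral_mapping = {
    "layers": "model.layers",
    "attention": "self_attn",
    "qscale_act": "input_scale",
    "qscale_weight": "weight_scale",
    "kv_fake_quantizer.qscale_act": "kv_scale",
    "wq": "q_proj",
}


def remap(name: str) -> str:
    modules = name.split(".")
    num_modules = len(modules)
    for i in range(num_modules):
        item = modules[i]
        next_item = modules[i + 1] if i < num_modules - 1 else None

        combined_item = (f"{item}.{next_item}"
                         if next_item is not None else None)

        # Two-token keys such as "kv_fake_quantizer.qscale_act" are tried
        # first, so the pair collapses into a single scale name.
        if combined_item in mistral_mapping:
            name = name.replace(combined_item, mistral_mapping[combined_item])
        elif item in mistral_mapping and mistral_mapping[item] not in name:
            name = name.replace(item, mistral_mapping[item])
    return name


assert remap("layers.0.attention.wq.qscale_weight") == \
    "model.layers.0.self_attn.q_proj.weight_scale"
assert remap("layers.0.attention.kv_fake_quantizer.qscale_act") == \
    "model.layers.0.self_attn.kv_scale"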
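
# Minimal sketch (not part of the patch): the params.json quantization section
# translated to an HF-style quantization_config, mirroring the logic added to
# vllm/transformers_utils/config.py above but pulled out as a free function.
# The example params dict is illustrative; only "qformat_weight" is inspected.

from typing import Any, Dict


def translate_quantization(config_dict: Dict[str, Any]) -> Dict[str, Any]:
    quantization = config_dict.get("quantization", {})
    if quantization.get("qformat_weight") == "fp8_e4m3":
        # FP8 E4M3 weights map onto the FP8 static per-tensor scheme.
        config_dict["quantization_config"] = {
            "quant_method": "fp8",
            "activation_scheme": "static",
        }
    else:
        raise ValueError(
            f"Found unknown quantization='{quantization}' in config")
    return config_dict


params = {"dim": 4096, "quantization": {"qformat_weight": "fp8_e4m3"}}
assert translate_quantization(params)["quantization_config"] == {
    "quant_method": "fp8",
    "activation_scheme": "static",
}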