[RFC] [Mistral] FP8 format (vllm-project#10130)
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: SzymonOzog <szymon.ozog@aleph-alpha.com>
2 people authored and SzymonOzog committed Feb 12, 2025
1 parent cf8ea05 commit 56d2f72
Showing 4 changed files with 55 additions and 12 deletions.
20 changes: 16 additions & 4 deletions vllm/model_executor/models/llama.py
@@ -467,6 +467,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     mistral_mapping = {
         "layers": "model.layers",
         "attention": "self_attn",
+        "qscale_act": "input_scale",
+        "qscale_weight": "weight_scale",
+        "kv_fake_quantizer.qscale_act": "kv_scale",
         "wq": "q_proj",
         "wk": "k_proj",
         "wv": "v_proj",
@@ -590,15 +593,24 @@ def permute(w: torch.Tensor, n_heads: int):
         modules = name.split(".")

         # rotary embeds should be sliced
-        if "wk" in modules:
+        if "wk" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
                                     self.config.num_key_value_heads)
-        elif "wq" in modules:
+        elif "wq" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
                                     self.config.num_attention_heads)

-        for item in modules:
-            if item in mapping and mapping[item] not in name:
+        num_modules = len(modules)
+        for i in range(num_modules):
+            item = modules[i]
+            next_item = modules[i + 1] if i < num_modules - 1 else None
+
+            combined_item = (f"{item}.{next_item}"
+                             if next_item is not None else None)
+
+            if combined_item in mapping:
+                name = name.replace(combined_item, mapping[combined_item])
+            elif item in mapping and mapping[item] not in name:
                 name = name.replace(item, mapping[item])

         return name, loaded_weight
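A minimal standalone sketch of the remapping loop above, using a subset of mistral_mapping; the sample parameter names are illustrative, not taken from a real checkpoint:

# Sketch of the two-token-aware remapping loop (subset of
# mistral_mapping; sample names are hypothetical).
mapping = {
    "layers": "model.layers",
    "attention": "self_attn",
    "qscale_act": "input_scale",
    "qscale_weight": "weight_scale",
    "kv_fake_quantizer.qscale_act": "kv_scale",
    "wq": "q_proj",
}


def remap(name: str) -> str:
    modules = name.split(".")
    num_modules = len(modules)
    for i in range(num_modules):
        item = modules[i]
        next_item = modules[i + 1] if i < num_modules - 1 else None
        combined_item = (f"{item}.{next_item}"
                         if next_item is not None else None)
        # Two-token keys win over single-token keys, so
        # "kv_fake_quantizer.qscale_act" maps to "kv_scale" as a unit
        # instead of being rewritten piecewise.
        if combined_item in mapping:
            name = name.replace(combined_item, mapping[combined_item])
        elif item in mapping and mapping[item] not in name:
            name = name.replace(item, mapping[item])
    return name


print(remap("layers.0.attention.wq.qscale_weight"))
# -> model.layers.0.self_attn.q_proj.weight_scale
print(remap("layers.0.attention.kv_fake_quantizer.qscale_act"))
# -> model.layers.0.self_attn.kv_scale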
7 changes: 5 additions & 2 deletions vllm/model_executor/models/pixtral.py
@@ -54,8 +54,11 @@ def get_max_pixtral_image_tokens(ctx: InputContext):
         tokenizer_mode=ctx.model_config.tokenizer_mode)
     mm_encoder = tokenizer.instruct.mm_encoder

-    max_image_size = mm_encoder.mm_config.max_image_size
-    image_patch_size = mm_encoder.mm_config.image_patch_size
+    image_config = mm_encoder.mm_config if hasattr(
+        mm_encoder, "mm_config") else mm_encoder.image_config
+
+    max_image_size = image_config.max_image_size
+    image_patch_size = image_config.image_patch_size

     return ((max_image_size // image_patch_size)**2)

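As a worked example: with a hypothetical max_image_size of 1024 and image_patch_size of 16 (typical Pixtral values, though the actual numbers come from the checkpoint's config), the bound is (1024 // 16)**2 = 64**2 = 4096 image tokens. The hasattr fallback covers mistral_common versions that expose these settings as image_config rather than mm_config.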
37 changes: 32 additions & 5 deletions vllm/transformers_utils/config.py
@@ -4,7 +4,7 @@
 import json
 import os
 from pathlib import Path
-from typing import Any, Dict, Optional, Type, Union
+from typing import Any, Dict, Literal, Optional, Type, Union

 import huggingface_hub
 from huggingface_hub import (file_exists, hf_hub_download, list_repo_files,
@@ -554,7 +554,8 @@ def recurse_elems(elem: Any):
             for key, value in elem.items():
                 key = config_mapping.get(key, key)
                 config_dict[key] = recurse_elems(value)
-            return PretrainedConfig(**config_dict)
+
+            return config_dict
         else:
             return elem

@@ -566,12 +567,30 @@ def recurse_elems(elem: Any):
     config_dict["max_position_embeddings"] = config_dict.get(
         "max_position_embeddings", 128_000)

+    if config_dict.get("quantization") is not None:
+        quantization = config_dict.get("quantization", {})
+        if quantization.get("qformat_weight") == "fp8_e4m3":
+            # This maps to the FP8 static per-tensor quantization scheme
+            quantization_config = {
+                "quant_method": "fp8",
+                "activation_scheme": "static"
+            }
+        else:
+            raise ValueError(
+                f"Found unknown quantization='{quantization}' in config")
+
+        config_dict["quantization_config"] = quantization_config
+
+    config_type: Literal["text",
+                         "multimodal"] = "multimodal" if config_dict.get(
+                             "vision_encoder") is not None else "text"
+
     if config_dict.get("moe") is not None:
         config_dict["architectures"] = ["MixtralForCausalLM"]
     else:
         config_dict["architectures"] = ["MistralForCausalLM"]

-    if config_dict.get("vision_encoder") is not None:
+    if config_type == "multimodal":
         multimodal_config = config_dict.pop("vision_encoder")

         config_dict = {
@@ -583,8 +602,16 @@

     config_dict.update(kwargs)

-    config = recurse_elems(config_dict)
-    return config
+    config_dict = recurse_elems(config_dict)
+
+    # transform to HF config format
+    if config_type == "multimodal":
+        config_dict["text_config"] = PretrainedConfig(
+            **config_dict["text_config"])
+        config_dict["vision_config"] = PretrainedConfig(
+            **config_dict["vision_config"])
+
+    return PretrainedConfig(**config_dict)


 def get_hf_image_processor_config(
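A minimal sketch of the new quantization translation, run on a hypothetical params.json fragment (only qformat_weight == "fp8_e4m3" is recognized; anything else raises):

# Sketch of the params.json -> HF-style quantization_config translation
# above; the input dict is illustrative.
config_dict = {"quantization": {"qformat_weight": "fp8_e4m3"}}

if config_dict.get("quantization") is not None:
    quantization = config_dict.get("quantization", {})
    if quantization.get("qformat_weight") == "fp8_e4m3":
        # Mistral's fp8_e4m3 format maps onto vLLM's FP8 static
        # per-tensor scheme.
        config_dict["quantization_config"] = {
            "quant_method": "fp8",
            "activation_scheme": "static",
        }
    else:
        raise ValueError(
            f"Found unknown quantization='{quantization}' in config")

print(config_dict["quantization_config"])
# {'quant_method': 'fp8', 'activation_scheme': 'static'}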
3 changes: 2 additions & 1 deletion vllm/transformers_utils/tokenizers/mistral.py
@@ -88,7 +88,8 @@ def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:


 def find_tokenizer_file(files: List[str]):
-    file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$")
+    file_pattern = re.compile(
+        r"^tokenizer\.model\.v.*$|^tekken\.json$|^tokenizer\.mm\.model\.v.*$")

     matched_files = [file for file in files if file_pattern.match(file)]
     if len(matched_files) > 1:
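A quick check of the extended pattern; the candidate filenames are hypothetical:

# The pattern now also accepts multimodal tokenizer files
# (tokenizer.mm.model.v*); sample filenames are made up.
import re

file_pattern = re.compile(
    r"^tokenizer\.model\.v.*$|^tekken\.json$|^tokenizer\.mm\.model\.v.*$")

candidates = [
    "tokenizer.model.v3",     # matches (text tokenizer)
    "tekken.json",            # matches (tekken tokenizer)
    "tokenizer.mm.model.v1",  # matches (new: multimodal tokenizer)
    "tokenizer.json",         # no match
]
print([f for f in candidates if file_pattern.match(f)])
# ['tokenizer.model.v3', 'tekken.json', 'tokenizer.mm.model.v1']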
