[Model] Add support for GraniteMoeShared models #13313
base: main
@@ -0,0 +1,355 @@
# SPDX-License-Identifier: Apache-2.0
"""Inference-only GraniteMoeShared model.

The architecture is the same as granitemoe but with the addition of shared
experts.
"""
from typing import Iterable, List, Optional, Set, Tuple

import torch
from torch import nn
from transformers.models.granitemoeshared import GraniteMoeSharedConfig

from vllm.attention import AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from . import mixtral
from .granitemoe import GraniteMoeAttention, GraniteMoeMoE
from .interfaces import SupportsLoRA
from .utils import make_layers, maybe_prefix


class GraniteMoeSharedMLP(nn.Module):

    def __init__(
        self,
        config: GraniteMoeSharedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.input_size = config.hidden_size
        self.hidden_size = config.shared_intermediate_size
        self.input_linear = MergedColumnParallelLinear(
Review comment: QQ: why doesn't input_linear support LoRA?
Reply: Thanks for taking a look! Honestly, I don't know what is required for a layer to support LoRA... I presume that there is no reason for a simple linear layer not to, but do please let me know if there are reasons I would need to investigate 😅 I added
            input_size=self.input_size,
            output_sizes=[self.hidden_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.input_linear")
        self.output_linear = RowParallelLinear(
            self.hidden_size,
            self.input_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.output_linear")
        if config.hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.input_linear(hidden_states)
        hidden_states = self.act_fn(hidden_states)
        hidden_states, _ = self.output_linear(hidden_states)
        return hidden_states


class GraniteMoeSharedDecoderLayer(nn.Module):

    def __init__(
        self,
        config: GraniteMoeSharedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = GraniteMoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            attention_multiplier=config.attention_multiplier)
        self.block_sparse_moe = GraniteMoeMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe")
        self.shared_mlp = None if \
            getattr(config, 'shared_intermediate_size', 0) == 0 \
            else GraniteMoeSharedMLP(
                config,
                quant_config=quant_config,
                prefix=f"{prefix}.shared_mlp"
            )

        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

        self.residual_multiplier = config.residual_multiplier

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + hidden_states * self.residual_multiplier
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        if self.shared_mlp is None:
            hidden_states = self.block_sparse_moe(hidden_states)
        else:
            # create a copy since block_sparse_moe modifies in-place
            moe_hidden_states = hidden_states.clone()
            moe_hidden_states = self.block_sparse_moe(moe_hidden_states)
            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
            del moe_hidden_states
        hidden_states = residual + hidden_states * self.residual_multiplier

        return hidden_states


@support_torch_compile
class GraniteMoeSharedModel(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            quant_config=quant_config,
        )
        self.embedding_multiplier = config.embedding_multiplier

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GraniteMoeSharedDecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers")

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            hidden_states *= self.embedding_multiplier
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states = layer(positions, hidden_states,
                                  kv_caches[i - self.start_layer],
                                  attn_metadata)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states = self.norm(hidden_states)
        return hidden_states


class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA):
    fall_back_to_pt_during_load = False

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config
        self.quant_config = quant_config

        self.model = GraniteMoeSharedModel(vllm_config=vllm_config,
                                           prefix=maybe_prefix(
                                               prefix, "model"))
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"))
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size,
                                                scale=1 /
                                                self.config.logits_scaling)

        self.sampler = get_sampler()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
            self, hidden_states: torch.Tensor,
            sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
            "residual":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        new_weights = {}
        for n, p in weights:
            if n.endswith('.block_sparse_moe.input_linear.weight'):
                for e in range(p.size(0)):
                    w1_name = n.replace(
                        '.block_sparse_moe.input_linear.weight',
                        f".block_sparse_moe.experts.{e}.w1.weight")
                    w3_name = n.replace(
                        '.block_sparse_moe.input_linear.weight',
                        f".block_sparse_moe.experts.{e}.w3.weight")
                    w1_param, w3_param = p[e].chunk(2, dim=0)
                    assert w1_name not in new_weights
                    assert w3_name not in new_weights
                    new_weights[w1_name] = w1_param
                    new_weights[w3_name] = w3_param
            elif n.endswith('.block_sparse_moe.output_linear.weight'):
                for e in range(p.size(0)):
                    w2_name = n.replace(
                        '.block_sparse_moe.output_linear.weight',
                        f".block_sparse_moe.experts.{e}.w2.weight")
                    w2_param = p[e]
                    assert w2_name not in new_weights
                    new_weights[w2_name] = w2_param
            elif n.endswith('.block_sparse_moe.router.layer.weight'):
                gate_name = n.replace('.block_sparse_moe.router.layer.weight',
                                      ".block_sparse_moe.gate.weight")
                assert gate_name not in new_weights
                new_weights[gate_name] = p
            elif n == 'lm_head.weight' and self.config.tie_word_embeddings:
                pass
            else:
                new_weights[n] = p
        return mixtral.MixtralForCausalLM.load_weights(self,
                                                       new_weights.items())
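As an aside for readers following the load_weights remapping above: the HF GraniteMoeShared checkpoint stores each layer's expert projections as stacked 3-D tensors, and the code splits them into the per-expert w1/w3/w2 names that the reused Mixtral loader expects (in Mixtral's convention, w1 is the gate projection, w3 the up projection, w2 the down projection). The snippet below is a toy sketch of that split, not part of the PR; the tensor shapes are made up purely for illustration.

import torch

# Hypothetical shapes, for illustration only.
num_experts, intermediate_size, hidden_size = 4, 8, 16

# HF layout: input_linear stacks every expert's fused gate/up weights.
input_linear = torch.randn(num_experts, 2 * intermediate_size, hidden_size)

per_expert = {}
for e in range(input_linear.size(0)):
    # Same split as load_weights: first half -> w1 (gate), second half -> w3 (up).
    w1, w3 = input_linear[e].chunk(2, dim=0)
    per_expert[f"experts.{e}.w1.weight"] = w1
    per_expert[f"experts.{e}.w3.weight"] = w3

assert all(t.shape == (intermediate_size, hidden_size)
           for t in per_expert.values())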
Review comment: I did a quick test and pipeline parallelism (PP) doesn't seem to work for the GraniteMoe model either. I can look into that as a follow-on.
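For the pipeline-parallel follow-up mentioned above, here is a minimal sketch of how one might reproduce the check with vLLM's offline API once this PR is applied. The model ID is a placeholder for a real GraniteMoeShared checkpoint, the two configurations should be run as separate processes, and the PP run assumes at least two GPUs plus whatever distributed-executor settings the deployment needs.

from vllm import LLM, SamplingParams

# Placeholder model ID; substitute an actual GraniteMoeShared checkpoint.
MODEL = "ibm-granite/<granite-moe-shared-checkpoint>"

# Baseline: single-GPU run, expected to work with this PR.
llm = LLM(model=MODEL)
out = llm.generate(["Hello,"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)

# Pipeline-parallel run (2 GPUs assumed); this is the configuration the
# review comment reports as not yet working for GraniteMoe models.
llm_pp = LLM(model=MODEL, pipeline_parallel_size=2)
out_pp = llm_pp.generate(["Hello,"], SamplingParams(max_tokens=16))
print(out_pp[0].outputs[0].text)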