[Model] Add support for falcon-11B (#5069)
Isotr0py authored May 27, 2024
1 parent fbdb7b3 commit 890aa93
Showing 1 changed file with 40 additions and 15 deletions.
55 changes: 40 additions & 15 deletions vllm/model_executor/models/falcon.py
@@ -41,7 +41,7 @@
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding)
+    ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput
@@ -246,18 +246,26 @@ def __init__(
         self.mlp = FalconMLP(config, quant_config)
         self.config = config
 
-        if config.new_decoder_architecture:
-            # The layer norm before self-attention
-            self.ln_attn = LayerNorm(hidden_size,
-                                     eps=config.layer_norm_epsilon)
-            # The layer norm before the MLP
-            self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        else:
+        if (config.num_ln_in_parallel_attn is None
+                and config.new_decoder_architecture):
+            config.num_ln_in_parallel_attn = 2
+
+        if not config.parallel_attn:
+            self.post_attention_layernorm = LayerNorm(
+                hidden_size, eps=config.layer_norm_epsilon)
             self.input_layernorm = LayerNorm(hidden_size,
                                              eps=config.layer_norm_epsilon)
-            if not config.parallel_attn:
-                self.post_attention_layernorm = LayerNorm(
-                    hidden_size, eps=config.layer_norm_epsilon)
+        else:
+            if config.num_ln_in_parallel_attn == 2:
+                # The layer norm before self-attention
+                self.ln_attn = LayerNorm(hidden_size,
+                                         eps=config.layer_norm_epsilon)
+                # The layer norm before the MLP
+                self.ln_mlp = LayerNorm(hidden_size,
+                                        eps=config.layer_norm_epsilon)
+            else:
+                self.input_layernorm = LayerNorm(hidden_size,
+                                                 eps=config.layer_norm_epsilon)
 
         self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                 or config.parallel_attn)
@@ -271,7 +279,7 @@ def forward(
     ) -> torch.Tensor:
         residual = hidden_states
 
-        if self.config.new_decoder_architecture:
+        if self.config.num_ln_in_parallel_attn == 2:
             attention_layernorm_out = self.ln_attn(hidden_states)
             mlp_layernorm_out = self.ln_mlp(hidden_states)
         else:
@@ -294,6 +302,10 @@ def forward(
             residual += attention_output
             mlp_layernorm_out = self.post_attention_layernorm(residual)
 
+        if (self.config.new_decoder_architecture and self.config.parallel_attn
+                and self.config.num_ln_in_parallel_attn == 1):
+            mlp_layernorm_out = attention_layernorm_out
+
         # MLP.
         mlp_output, mlp_bias = self.mlp(mlp_layernorm_out)
         if self.reduce_row_parallel_results and mlp_bias is not None:
@@ -375,7 +387,20 @@ def __init__(
         self.config = config
         self.quant_config = quant_config
         self.transformer = FalconModel(config, cache_config, quant_config)
-        self.lm_head_weight = self.transformer.word_embeddings.weight
+        # Only Falcon-11B unties the lm_head weight from the word embeddings,
+        # and earlier Falcon models have no tie_word_embeddings config, so
+        # tie_word_embeddings defaults to True.
+        self.tie_word_embeddings = (config.tie_word_embeddings
+                                    if config.tie_word_embeddings is not None
+                                    else True)
+        if self.tie_word_embeddings:
+            self.lm_head_weight = self.transformer.word_embeddings.weight
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+            )
+            self.lm_head_weight = self.lm_head.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -419,8 +444,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in weights:
-            if name == "lm_head.weight":
-                # Falcon uses tied embeddings.
+            if name == "lm_head.weight" and self.tie_word_embeddings:
+                # Falcon uses tied embeddings except for Falcon-11B.
                 continue
             # Skip loading extra bias for GPTQ models.
             if name.endswith(".bias") and name not in params_dict:
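For reference, a minimal usage sketch (not part of this commit): after this change, Falcon-11B should load through vLLM's regular offline-inference entry point. The model id "tiiuae/falcon-11B" and the sampling settings below are assumptions for illustration; any supported Falcon checkpoint and values would do.

    from vllm import LLM, SamplingParams

    # Assumed Hugging Face model id for the Falcon-11B checkpoint; adjust as needed.
    llm = LLM(model="tiiuae/falcon-11B")
    sampling_params = SamplingParams(temperature=0.8, max_tokens=64)

    # Generate a completion for a single prompt and print the text.
    outputs = llm.generate(["The Falcon series of language models was trained on"],
                           sampling_params)
    print(outputs[0].outputs[0].text)

Because tie_word_embeddings defaults to True when the config omits it, older Falcon checkpoints keep sharing the embedding weight with lm_head, while Falcon-11B gets its own ParallelLMHead loaded from lm_head.weight.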
