diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py
index ba707adb03dfe..9618652f70d23 100644
--- a/vllm/model_executor/models/falcon.py
+++ b/vllm/model_executor/models/falcon.py
@@ -41,7 +41,7 @@
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding)
+    ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput
@@ -246,18 +246,26 @@ def __init__(
         self.mlp = FalconMLP(config, quant_config)
         self.config = config
 
-        if config.new_decoder_architecture:
-            # The layer norm before self-attention
-            self.ln_attn = LayerNorm(hidden_size,
-                                     eps=config.layer_norm_epsilon)
-            # The layer norm before the MLP
-            self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        else:
+        if (config.num_ln_in_parallel_attn is None
+                and config.new_decoder_architecture):
+            config.num_ln_in_parallel_attn = 2
+
+        if not config.parallel_attn:
+            self.post_attention_layernorm = LayerNorm(
+                hidden_size, eps=config.layer_norm_epsilon)
             self.input_layernorm = LayerNorm(hidden_size,
                                              eps=config.layer_norm_epsilon)
-            if not config.parallel_attn:
-                self.post_attention_layernorm = LayerNorm(
-                    hidden_size, eps=config.layer_norm_epsilon)
+        else:
+            if config.num_ln_in_parallel_attn == 2:
+                # The layer norm before self-attention
+                self.ln_attn = LayerNorm(hidden_size,
+                                         eps=config.layer_norm_epsilon)
+                # The layer norm before the MLP
+                self.ln_mlp = LayerNorm(hidden_size,
+                                        eps=config.layer_norm_epsilon)
+            else:
+                self.input_layernorm = LayerNorm(hidden_size,
+                                                 eps=config.layer_norm_epsilon)
 
         self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                 or config.parallel_attn)
@@ -271,7 +279,7 @@ def forward(
     ) -> torch.Tensor:
         residual = hidden_states
 
-        if self.config.new_decoder_architecture:
+        if self.config.num_ln_in_parallel_attn == 2:
             attention_layernorm_out = self.ln_attn(hidden_states)
             mlp_layernorm_out = self.ln_mlp(hidden_states)
         else:
@@ -294,6 +302,10 @@ def forward(
             residual += attention_output
             mlp_layernorm_out = self.post_attention_layernorm(residual)
 
+        if (self.config.new_decoder_architecture and self.config.parallel_attn
+                and self.config.num_ln_in_parallel_attn == 1):
+            mlp_layernorm_out = attention_layernorm_out
+
         # MLP.
         mlp_output, mlp_bias = self.mlp(mlp_layernorm_out)
         if self.reduce_row_parallel_results and mlp_bias is not None:
@@ -375,7 +387,20 @@ def __init__(
         self.config = config
         self.quant_config = quant_config
         self.transformer = FalconModel(config, cache_config, quant_config)
-        self.lm_head_weight = self.transformer.word_embeddings.weight
+        # Only Falcon-11B doesn't share lm_head weight with word embeddings,
+        # and previous Falcon models don't have a tie_word_embeddings config,
+        # so we set tie_word_embeddings to True by default.
+        self.tie_word_embeddings = (config.tie_word_embeddings
+                                    if config.tie_word_embeddings is not None
+                                    else True)
+        if self.tie_word_embeddings:
+            self.lm_head_weight = self.transformer.word_embeddings.weight
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+            )
+            self.lm_head_weight = self.lm_head.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -419,8 +444,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in weights:
-            if name == "lm_head.weight":
-                # Falcon uses tied embeddings.
+            if name == "lm_head.weight" and self.tie_word_embeddings:
+                # Falcon uses tied embeddings except Falcon-11B.
                 continue
             # Skip loading extra bias for GPTQ models.
             if name.endswith(".bias") and name not in params_dict:
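
Below is a minimal sketch (not part of the patch) of the two config fallbacks the new code relies on: `num_ln_in_parallel_attn` defaulting to 2 for new-decoder-architecture checkpoints that don't set it, and `tie_word_embeddings` defaulting to True when the config omits it. `SimpleNamespace` is a stand-in for the Hugging Face `FalconConfig`, and the Falcon-11B-style values in the second example are assumed for illustration only.

```python
# Illustrative sketch of the config fallbacks used by this patch.
# SimpleNamespace stands in for transformers' FalconConfig.
from types import SimpleNamespace


def resolve_falcon_defaults(config):
    # Older new-architecture checkpoints (e.g. Falcon-40B) don't define
    # num_ln_in_parallel_attn, so treat a missing value as 2 (ln_attn + ln_mlp).
    if (getattr(config, "num_ln_in_parallel_attn", None) is None
            and config.new_decoder_architecture):
        config.num_ln_in_parallel_attn = 2

    # Only Falcon-11B unties lm_head from the word embeddings; older configs
    # have no tie_word_embeddings field, so default to True (tied weights).
    tie = getattr(config, "tie_word_embeddings", None)
    config.tie_word_embeddings = True if tie is None else tie
    return config


# Falcon-40B-style config: neither field present in the checkpoint config.
old_cfg = resolve_falcon_defaults(
    SimpleNamespace(new_decoder_architecture=True))
assert old_cfg.num_ln_in_parallel_attn == 2
assert old_cfg.tie_word_embeddings is True

# Falcon-11B-style config (assumed values): both fields set explicitly.
new_cfg = resolve_falcon_defaults(
    SimpleNamespace(new_decoder_architecture=True,
                    num_ln_in_parallel_attn=1,
                    tie_word_embeddings=False))
assert new_cfg.num_ln_in_parallel_attn == 1
assert new_cfg.tie_word_embeddings is False
```

With an untied Falcon-11B-style config, the second branch of the `__init__` change above allocates a separate `ParallelLMHead` and loads `lm_head.weight` from the checkpoint instead of skipping it in `load_weights`.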