diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py
index 710809093f38..e8d22b781d52 100644
--- a/src/transformers/models/llama/configuration_llama.py
+++ b/src/transformers/models/llama/configuration_llama.py
@@ -45,6 +45,8 @@ class LlamaConfig(PretrainedConfig):
             Number of hidden layers in the Transformer decoder.
         num_attention_heads (`int`, *optional*, defaults to 32):
             Number of attention heads for each attention layer in the Transformer decoder.
+        head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
+            The attention head dimension.
         num_key_value_heads (`int`, *optional*):
             This is the number of key_value heads that should be used to implement Grouped Query Attention. If
             `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
@@ -147,6 +149,7 @@ def __init__(
         intermediate_size=11008,
         num_hidden_layers=32,
         num_attention_heads=32,
+        head_dim=None,
         num_key_value_heads=None,
         hidden_act="silu",
         max_position_embeddings=2048,
@@ -171,6 +174,7 @@ def __init__(
         self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
+        self.head_dim = head_dim or hidden_size // num_attention_heads
 
         # for backward compatibility
         if num_key_value_heads is None:
diff --git a/src/transformers/models/llama/modeling_flax_llama.py b/src/transformers/models/llama/modeling_flax_llama.py
index 1c9f1c4adc3e..4ce25d33ce16 100644
--- a/src/transformers/models/llama/modeling_flax_llama.py
+++ b/src/transformers/models/llama/modeling_flax_llama.py
@@ -198,7 +198,7 @@ def setup(self):
         config = self.config
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
-        self.head_dim = self.embed_dim // self.num_heads
+        self.head_dim = config.head_dim
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.attention_softmax_in_fp32 = self.dtype is not jnp.float32
@@ -214,12 +214,6 @@ def setup(self):
         self.k_proj = dense(self.num_key_value_heads * self.head_dim)
         self.v_proj = dense(self.num_key_value_heads * self.head_dim)
         self.o_proj = dense(self.embed_dim)
-        if (self.head_dim * self.num_heads) != self.embed_dim:
-            raise ValueError(
-                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.embed_dim}"
-                f" and `num_heads`: {self.num_heads})."
-            )
-
         self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
 
         self.rotary_emb = FlaxLlamaRotaryEmbedding(config, dtype=self.dtype)
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 4a0887629cc6..e7020af42003 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -340,23 +340,17 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
         self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
-        self.head_dim = self.hidden_size // self.num_heads
+        self.head_dim = config.head_dim
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
         self.is_causal = True
 
-        if (self.head_dim * self.num_heads) != self.hidden_size:
-            raise ValueError(
-                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
-                f" and `num_heads`: {self.num_heads})."
-            )
-
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
-        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
 
         # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)
         self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
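With this patch applied, `head_dim` becomes an explicit config field that can differ from `hidden_size // num_attention_heads`; when left unset it falls back to the old quotient, which is why the divisibility check is no longer needed and `o_proj` takes `num_heads * head_dim` as its input width. A minimal usage sketch (the numeric values below are illustrative, not from any released checkpoint):

```python
from transformers import LlamaConfig

# Explicit head_dim, decoupled from hidden_size // num_attention_heads.
config = LlamaConfig(
    hidden_size=2048,
    num_attention_heads=8,
    head_dim=128,  # overrides the default of 2048 // 8 = 256
)
print(config.head_dim)  # 128

# Left unset, head_dim falls back to the previous behaviour.
default_config = LlamaConfig(hidden_size=2048, num_attention_heads=8)
print(default_config.head_dim)  # 2048 // 8 == 256
```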