Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow head_dim for llama like gemma or mistral #32847

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/transformers/models/llama/configuration_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class LlamaConfig(PretrainedConfig):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
The attention head dimension.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
Expand Down Expand Up @@ -147,6 +149,7 @@ def __init__(
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
head_dim=None,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
Expand All @@ -171,6 +174,7 @@ def __init__(
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim or hidden_size // num_attention_heads

# for backward compatibility
if num_key_value_heads is None:
Expand Down
8 changes: 1 addition & 7 deletions src/transformers/models/llama/modeling_flax_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def setup(self):
config = self.config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.head_dim = config.head_dim
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.attention_softmax_in_fp32 = self.dtype is not jnp.float32
Expand All @@ -214,12 +214,6 @@ def setup(self):
self.k_proj = dense(self.num_key_value_heads * self.head_dim)
self.v_proj = dense(self.num_key_value_heads * self.head_dim)
self.o_proj = dense(self.embed_dim)
if (self.head_dim * self.num_heads) != self.embed_dim:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.embed_dim}"
f" and `num_heads`: {self.num_heads})."
)

self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
self.rotary_emb = FlaxLlamaRotaryEmbedding(config, dtype=self.dtype)

Expand Down
10 changes: 2 additions & 8 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,23 +340,17 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.head_dim = config.head_dim
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True

if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)

self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)

# TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)
self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
Expand Down