Update on "allow customized head_dim"
This allows `head_dim` to be set explicitly instead of always being derived as `dim // n_heads`, resolving the request in this [post](https://fb.workplace.com/groups/pytorch.edge.users/permalink/1574875706716050/).

Similar change in HF: huggingface/transformers#32502

Differential Revision: [D65974454](https://our.internmc.facebook.com/intern/diff/D65974454/)

[ghstack-poisoned]
helunwencser committed Nov 18, 2024
2 parents 0120876 + b0b44b3 commit 1c87ce3
Showing 1 changed file with 6 additions and 3 deletions.
examples/models/llama/llama_transformer.py (6 additions & 3 deletions)
@@ -143,6 +143,9 @@ def __post_init__(self):
                 hidden_dim = int(self.ffn_dim_multiplier * hidden_dim)
             self.hidden_dim = find_multiple(hidden_dim, multiple_of)
 
+        if self.head_dim is None:
+            self.head_dim = self.dim // self.n_heads
+
 
 class KVCache(nn.Module):
     def __init__(
@@ -273,7 +276,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.n_local_heads = self.n_heads // model_parallel_size
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
-        self.head_dim = args.dim // self.n_heads if args.head_dim is None else args.head_dim
+        self.head_dim = args.head_dim
         self.max_batch_size = args.max_batch_size
         self.max_seq_len = args.max_seq_len
         self.dim = args.dim
@@ -426,7 +429,7 @@ def __init__(self, layer_id: int, args: ModelArgs):
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
-        self.head_dim = args.dim // args.n_heads if args.head_dim is None else args.head_dim
+        self.head_dim = args.head_dim
         self.attention = Attention(args, layer_id)
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
@@ -473,7 +476,7 @@ def __init__(self, params: ModelArgs):
             precompute_freqs_cis, use_scaled=params.use_scaled_rope
         )
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
-            params.dim // params.n_heads if params.head_dim is None else params.head_dim,
+            params.head_dim,
             (
                 params.max_seq_len  # Normal llama2.
                 if params.ffn_dim_multiplier is None
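
With the default moved into `__post_init__`, the call sites in `Attention`, `TransformerBlock`, and `Transformer` can read `args.head_dim` directly, as the remaining hunks show. For context, here is a minimal sketch of the new defaulting behavior, assuming a simplified `ModelArgs` that keeps only the fields touched by this diff (the real dataclass has many more):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelArgs:  # simplified stand-in for the dataclass in llama_transformer.py
    dim: int = 4096
    n_heads: int = 32
    head_dim: Optional[int] = None  # may now be set explicitly by the caller

    def __post_init__(self):
        # Mirrors the added hunk: derive head_dim only when it was not provided.
        if self.head_dim is None:
            self.head_dim = self.dim // self.n_heads


# Default behavior is unchanged: head_dim falls back to dim // n_heads.
assert ModelArgs().head_dim == 128
# A customized head_dim no longer has to equal dim // n_heads.
assert ModelArgs(head_dim=256).head_dim == 256
```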
