Skip to content

Commit

Permalink
allow customized head_dim
Browse files Browse the repository at this point in the history
Pull Request resolved: #6872

This is for resolving the ask in this [post](https://fb.workplace.com/groups/pytorch.edge.users/permalink/1574875706716050/).

Similar change in HF: huggingface/transformers#32502

Differential Revision: [D65974454](https://our.internmc.facebook.com/intern/diff/D65974454/)
ghstack-source-id: 254171929
  • Loading branch information
helunwencser committed Nov 18, 2024
1 parent e95f171 commit 5e87ad3
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions examples/models/llama/llama_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class ModelArgs:
n_kv_heads: Optional[int] = None
vocab_size: int = -1 # defined later by tokenizer
hidden_dim: Optional[int] = None
head_dim: Optional[int] = None
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
ffn_dim_multiplier: Optional[float] = None
norm_eps: float = 1e-5
Expand Down Expand Up @@ -272,7 +273,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
self.n_local_heads = self.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // self.n_heads
self.head_dim = args.dim // self.n_heads if args.head_dim is None else args.head_dim
self.max_batch_size = args.max_batch_size
self.max_seq_len = args.max_seq_len
self.dim = args.dim
Expand Down Expand Up @@ -304,7 +305,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
)
self.SDPA = SDPA(
kv_cache=self.kv_cache,
dim=self.dim,
dim=self.n_local_heads * self.head_dim,
head_dim=self.head_dim,
n_rep=self.n_rep,
max_seq_len=self.max_seq_len,
Expand Down Expand Up @@ -425,7 +426,7 @@ def __init__(self, layer_id: int, args: ModelArgs):
self.use_kv_cache = args.use_kv_cache
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.head_dim = args.dim // args.n_heads if args.head_dim is None else args.head_dim
self.attention = Attention(args, layer_id)
if args.moe:
self.block_sparse_moe = MOEFeedForward(args)
Expand Down Expand Up @@ -472,7 +473,7 @@ def __init__(self, params: ModelArgs):
precompute_freqs_cis, use_scaled=params.use_scaled_rope
)
freqs_cos, freqs_sin = self.precompute_freqs_cis(
params.dim // params.n_heads,
params.dim // params.n_heads if params.head_dim is None else params.head_dim,
(
params.max_seq_len # Normal llama2.
if params.ffn_dim_multiplier is None
Expand Down

0 comments on commit 5e87ad3

Please sign in to comment.