allow customized head_dim
This resolves the ask in this [post](https://fb.workplace.com/groups/pytorch.edge.users/permalink/1574875706716050/).

Similar change in HF: huggingface/transformers#32502

Differential Revision: [D65974454](https://our.internmc.facebook.com/intern/diff/D65974454/)

[ghstack-poisoned]
helunwencser committed Nov 14, 2024
1 parent ecdc007 commit c083d85
Showing 1 changed file with 2 additions and 1 deletion.
examples/models/llama/llama_transformer.py: 3 changes (2 additions & 1 deletion)
@@ -85,6 +85,7 @@ class ModelArgs:
     n_kv_heads: Optional[int] = None
     vocab_size: int = -1  # defined later by tokenizer
     hidden_dim: Optional[int] = None
+    head_dim: Optional[int] = None
     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
     ffn_dim_multiplier: Optional[float] = None
     norm_eps: float = 1e-5
@@ -272,7 +273,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.n_local_heads = self.n_heads // model_parallel_size
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
-        self.head_dim = args.dim // self.n_heads
+        self.head_dim = args.dim // self.n_heads if args.head_dim is None else args.head_dim
         self.max_batch_size = args.max_batch_size
         self.max_seq_len = args.max_seq_len
         self.dim = args.dim
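
For illustration, here is a minimal sketch of the fallback behavior introduced by this change, using a hypothetical trimmed-down `ModelArgs` with only the fields relevant here (the real class in `llama_transformer.py` has many more):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelArgs:
    # Trimmed-down stand-in for the real ModelArgs in llama_transformer.py.
    dim: int = 4096
    n_heads: int = 32
    head_dim: Optional[int] = None  # new field: explicit per-head dimension


def resolve_head_dim(args: ModelArgs) -> int:
    # Mirrors the patched line in the attention module's __init__: derive
    # head_dim from dim // n_heads only when it is not set explicitly.
    return args.dim // args.n_heads if args.head_dim is None else args.head_dim


print(resolve_head_dim(ModelArgs()))              # 128, derived as 4096 // 32
print(resolve_head_dim(ModelArgs(head_dim=256)))  # 256, taken from the config
```

Because `head_dim` defaults to `None`, existing configs that do not set it keep the previous `dim // n_heads` behavior unchanged.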
