diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py index 20b8b1e30d..76cd218d65 100644 --- a/examples/models/llama/llama_transformer.py +++ b/examples/models/llama/llama_transformer.py @@ -85,6 +85,7 @@ class ModelArgs: n_kv_heads: Optional[int] = None vocab_size: int = -1 # defined later by tokenizer hidden_dim: Optional[int] = None + head_dim: Optional[int] = None multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 ffn_dim_multiplier: Optional[float] = None norm_eps: float = 1e-5 @@ -272,7 +273,7 @@ def __init__(self, args: ModelArgs, layer_id: int): self.n_local_heads = self.n_heads // model_parallel_size self.n_local_kv_heads = self.n_kv_heads // model_parallel_size self.n_rep = self.n_local_heads // self.n_local_kv_heads - self.head_dim = args.dim // self.n_heads + self.head_dim = args.dim // self.n_heads if args.head_dim is None else args.head_dim self.max_batch_size = args.max_batch_size self.max_seq_len = args.max_seq_len self.dim = args.dim