Add default TP plan for all models with backend support (#35870)
* Add some tp plans!

* More tp plans!

* Add it in the comment

* style

* Update configuration_mixtral.py

* Update configuration_phi.py

* update the layout according to special archs

* fix mixtral

* style

* trigger CIs

* trigger CIs

* CIs

* olmo2

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Cyrilvallez and ArthurZucker authored Jan 28, 2025
1 parent 96625d8 commit 3613f56
Showing 17 changed files with 134 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/transformers/modeling_utils.py
@@ -1294,7 +1294,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix

# This flag signals that the model can be used as an efficient backend in TGI and vLLM
# In practice, it means that it supports attention interface functions, fully passes the kwargs
- # through all modules up to the Attention layer, and can slice logits with Tensor
+ # through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan
_supports_attention_backend = False

@property
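
For context, here is a minimal sketch of how such a default TP plan gets used at load time. It assumes the `tp_plan="auto"` path in `from_pretrained`, which reads `config.base_model_tp_plan`, plus a multi-GPU node and a `torchrun` launch; the model id and prompt are placeholders.

# Launch with: torchrun --nproc-per-node 4 tp_demo.py
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

local_rank = int(os.environ["LOCAL_RANK"])
torch.distributed.init_process_group("nccl")
torch.cuda.set_device(local_rank)

model_id = "mistralai/Mixtral-8x7B-v0.1"  # placeholder: any model whose config defines base_model_tp_plan
model = AutoModelForCausalLM.from_pretrained(model_id, tp_plan="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Tensor parallelism shards the weights", return_tensors="pt").input_ids.to(f"cuda:{local_rank}")
logits = model(inputs).logits  # each rank holds only a shard of every nn.Linear matched by the plan
print(logits.shape)
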
9 changes: 9 additions & 0 deletions src/transformers/models/cohere/configuration_cohere.py
@@ -139,6 +139,15 @@ class CohereConfig(PretrainedConfig):

model_type = "cohere"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

def __init__(
self,
9 changes: 9 additions & 0 deletions src/transformers/models/cohere2/configuration_cohere2.py
@@ -139,6 +139,15 @@ class Cohere2Config(PretrainedConfig):

model_type = "cohere2"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

def __init__(
self,
9 changes: 9 additions & 0 deletions src/transformers/models/cohere2/modular_cohere2.py
@@ -163,6 +163,15 @@ class Cohere2Config(PretrainedConfig):

model_type = "cohere2"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

def __init__(
self,
9 changes: 9 additions & 0 deletions src/transformers/models/gemma/configuration_gemma.py
@@ -93,6 +93,15 @@ class GemmaConfig(PretrainedConfig):

model_type = "gemma"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

def __init__(
self,
9 changes: 9 additions & 0 deletions src/transformers/models/gemma/modular_gemma.py
@@ -117,6 +117,15 @@ class GemmaConfig(PretrainedConfig):

model_type = "gemma"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

def __init__(
self,
9 changes: 9 additions & 0 deletions src/transformers/models/gemma2/configuration_gemma2.py
@@ -97,6 +97,15 @@ class Gemma2Config(PretrainedConfig):

model_type = "gemma2"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

def __init__(
self,
9 changes: 9 additions & 0 deletions src/transformers/models/gemma2/modular_gemma2.py
@@ -123,6 +123,15 @@ class Gemma2Config(PretrainedConfig):

model_type = "gemma2"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

def __init__(
self,
8 changes: 8 additions & 0 deletions src/transformers/models/glm/configuration_glm.py
@@ -85,6 +85,14 @@ class GlmConfig(PretrainedConfig):

model_type = "glm"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
"layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
}

def __init__(
self,
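
As an aside on the `colwise_rep` comments above: the sketch below (plain PyTorch with made-up sizes, not the actual Glm modules) shows why a fused gate_up projection whose output is later split with `chunk` cannot have that output column-sharded.

import torch
import torch.nn as nn

hidden, intermediate = 8, 16
x = torch.randn(2, hidden)
gate_up_proj = nn.Linear(hidden, 2 * intermediate, bias=False)

# Replicated case: the first `intermediate` columns are the gate, the rest are up.
full = gate_up_proj(x)
gate, up = full.chunk(2, dim=-1)

# Simulated 2-way column shard of the fused output: rank 0 holds columns [0, 16),
# i.e. all of the gate and none of up, so a local chunk(2, dim=-1) hands
# gate columns 8..15 to the "up" branch.
rank0_shard = full[:, :intermediate]
local_gate, local_up = rank0_shard.chunk(2, dim=-1)
print(torch.equal(local_up, up[:, : intermediate // 2]))  # False: local_up is really gate columns
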
9 changes: 9 additions & 0 deletions src/transformers/models/helium/configuration_helium.py
@@ -86,6 +86,15 @@ class HeliumConfig(PretrainedConfig):

model_type = "helium"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

def __init__(
self,
10 changes: 10 additions & 0 deletions src/transformers/models/mixtral/configuration_mixtral.py
@@ -109,6 +109,16 @@ class MixtralConfig(PretrainedConfig):

model_type = "mixtral"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.block_sparse_moe.gate": "colwise_rep", # we need to replicate here to correctly route experts
"layers.*.block_sparse_moe.experts.*.w1": "colwise",
"layers.*.block_sparse_moe.experts.*.w2": "rowwise",
"layers.*.block_sparse_moe.experts.*.w3": "colwise",
}

def __init__(
self,
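
A similar aside for the replicated router gate above: top-k routing needs the logits of every expert for every token, so the gate's num_experts output dimension cannot be sharded. The toy sketch below (plain PyTorch, not the Mixtral module itself) mirrors that routing step.

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden, num_experts, top_k = 16, 8, 2
x = torch.randn(4, hidden)                 # 4 token representations
gate = nn.Linear(hidden, num_experts, bias=False)

router_logits = gate(x)                    # [4, num_experts]; must be complete on every rank
routing_weights = F.softmax(router_logits, dim=-1)
weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
print(selected_experts)                    # expert indices chosen from the full set of experts
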
9 changes: 9 additions & 0 deletions src/transformers/models/olmo/configuration_olmo.py
@@ -106,6 +106,15 @@ class OlmoConfig(PretrainedConfig):

model_type = "olmo"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

def __init__(
self,
9 changes: 9 additions & 0 deletions src/transformers/models/olmo2/configuration_olmo2.py
@@ -89,6 +89,15 @@ class Olmo2Config(PretrainedConfig):

model_type = "olmo2"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

def __init__(
self,
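
And for the replicated projections above: Olmo2 applies a norm over the full width of the q and k projection outputs, so a column-sharded projection would give each rank the wrong norm statistics. A toy sketch (plain PyTorch, not the Olmo2 attention module; assumes a torch build that ships nn.RMSNorm):

import torch
import torch.nn as nn

dim = 16
x = torch.randn(2, dim)
q_proj = nn.Linear(dim, dim, bias=False)
q_norm = nn.RMSNorm(dim)                   # normalizes over all `dim` features

q_full = q_norm(q_proj(x))                 # replicated case: correct statistics

# Simulated 2-way column shard: a norm applied locally sees only half the features,
# so its root-mean-square denominator differs from the full-width one.
q_shard = q_proj(x)[:, : dim // 2]
q_local = nn.RMSNorm(dim // 2)(q_shard)
print(torch.allclose(q_local, q_full[:, : dim // 2]))  # generally False
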
9 changes: 9 additions & 0 deletions src/transformers/models/olmo2/modular_olmo2.py
@@ -100,6 +100,15 @@ class Olmo2Config(OlmoConfig):
"""

model_type = "olmo2"
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

def __init__(
self,
8 changes: 8 additions & 0 deletions src/transformers/models/phi/configuration_phi.py
@@ -138,6 +138,14 @@ class PhiConfig(PretrainedConfig):

model_type = "phi"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.dense": "rowwise",
"layers.*.mlp.fc1": "colwise",
"layers.*.mlp.fc2": "rowwise",
}

def __init__(
self,
6 changes: 6 additions & 0 deletions src/transformers/models/phi3/configuration_phi3.py
@@ -107,6 +107,12 @@ class Phi3Config(PretrainedConfig):

model_type = "phi3"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.qkv_proj": "colwise_rep", # we need to replicate here due to the slicing of qkv
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the slicing of qkv
"layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
"layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
}

def __init__(
self,
2 changes: 2 additions & 0 deletions src/transformers/pytorch_utils.py
@@ -358,5 +358,7 @@ def translate_to_torch_parallel_style(style: str):
return RowwiseParallel()
elif style == "colwise_rep":
return ColwiseParallel(output_layouts=Replicate())
elif style == "rowwise_rep":
return RowwiseParallel(input_layouts=Replicate())
else:
raise ValueError(f"Unsupported parallel style value: {style}")
