Add default TP plan for all models with backend support #35870

Merged · 14 commits · Jan 28, 2025
2 changes: 1 addition & 1 deletion src/transformers/modeling_utils.py
@@ -1294,7 +1294,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix

# This flag signal that the model can be used as an efficient backend in TGI and vLLM
# In practice, it means that they support attention interface functions, fully pass the kwargs
-    # through all modules up to the Attention layer, and can slice logits with Tensor
+    # through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan
_supports_attention_backend = False

@property
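For context on how these per-model plans are consumed, here is a minimal usage sketch (not part of this PR): a model whose config defines `base_model_tp_plan` is loaded with tensor parallelism across the ranks of a `torchrun` launch. The model id, prompt, and process-group setup are illustrative assumptions.

```python
# Minimal sketch: launch with e.g. `torchrun --nproc-per-node=4 run_tp.py`.
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.1-8B-Instruct"  # illustrative; any model with a default TP plan

rank = int(os.environ["RANK"])
device = torch.device(f"cuda:{rank}")
torch.distributed.init_process_group("nccl", device_id=device)

# `tp_plan="auto"` picks up `config.base_model_tp_plan` and shards the matching
# nn.Linear modules column- or row-wise across the participating ranks.
model = AutoModelForCausalLM.from_pretrained(model_id, tp_plan="auto", torch_dtype=torch.bfloat16)

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Tensor parallelism splits weight matrices across GPUs.", return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)
```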
9 changes: 9 additions & 0 deletions src/transformers/models/cohere/configuration_cohere.py
@@ -139,6 +139,15 @@ class CohereConfig(PretrainedConfig):

model_type = "cohere"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

def __init__(
self,
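The same seven-entry plan recurs for most dense decoder models in this PR (Cohere, Cohere2, Gemma, Gemma2, Helium, OLMo): q/k/v and gate/up projections are sharded column-wise, while the output and down projections are sharded row-wise, so each attention and MLP block needs only a single all-reduce. A rough sketch of what the strings translate to in torch's tensor-parallel styles follows; the layer names and the commented `parallelize_module` call are illustrative, not the internal transformers code path.

```python
# Illustrative mapping of the plan above onto torch.distributed tensor-parallel styles.
# Keys in `base_model_tp_plan` are wildcard patterns over submodule names; transformers
# matches them against the instantiated model and applies the corresponding style.
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

decoder_layer_plan = {
    "self_attn.q_proj": ColwiseParallel(),   # weight split along output features
    "self_attn.k_proj": ColwiseParallel(),
    "self_attn.v_proj": ColwiseParallel(),
    "self_attn.o_proj": RowwiseParallel(),   # weight split along input features
    "mlp.gate_proj": ColwiseParallel(),
    "mlp.up_proj": ColwiseParallel(),
    "mlp.down_proj": RowwiseParallel(),
}
# parallelize_module(decoder_layer, device_mesh, decoder_layer_plan) would shard one
# decoder layer; pairing colwise -> rowwise keeps activations sharded in between and
# requires one all-reduce at the end of the attention block and one for the MLP.
```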
9 changes: 9 additions & 0 deletions src/transformers/models/cohere2/configuration_cohere2.py
@@ -139,6 +139,15 @@ class Cohere2Config(PretrainedConfig):

model_type = "cohere2"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

def __init__(
self,
9 changes: 9 additions & 0 deletions src/transformers/models/cohere2/modular_cohere2.py
@@ -163,6 +163,15 @@ class Cohere2Config(PretrainedConfig):

model_type = "cohere2"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

def __init__(
self,
9 changes: 9 additions & 0 deletions src/transformers/models/gemma/configuration_gemma.py
@@ -93,6 +93,15 @@ class GemmaConfig(PretrainedConfig):

model_type = "gemma"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

def __init__(
self,
9 changes: 9 additions & 0 deletions src/transformers/models/gemma/modular_gemma.py
@@ -117,6 +117,15 @@ class GemmaConfig(PretrainedConfig):

model_type = "gemma"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

def __init__(
self,
9 changes: 9 additions & 0 deletions src/transformers/models/gemma2/configuration_gemma2.py
@@ -97,6 +97,15 @@ class Gemma2Config(PretrainedConfig):

model_type = "gemma2"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

def __init__(
self,
9 changes: 9 additions & 0 deletions src/transformers/models/gemma2/modular_gemma2.py
@@ -123,6 +123,15 @@ class Gemma2Config(PretrainedConfig):

model_type = "gemma2"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

def __init__(
self,
8 changes: 8 additions & 0 deletions src/transformers/models/glm/configuration_glm.py
@@ -85,6 +85,14 @@ class GlmConfig(PretrainedConfig):

model_type = "glm"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
"layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
}

def __init__(
self,
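GLM fuses the gate and up projections into a single `gate_up_proj` and later splits the activation with `chunk`, which is why its plan falls back to the replicated variants. A small sketch of the constraint, with hypothetical sizes and a simplified MLP:

```python
# Why GLM's fused MLP projection is "colwise_rep" rather than "colwise" (sketch).
import torch
import torch.nn as nn

hidden, intermediate = 8, 16
x = torch.randn(1, hidden)
gate_up_proj = nn.Linear(hidden, 2 * intermediate, bias=False)

out = gate_up_proj(x)
gate, up = out.chunk(2, dim=-1)  # assumes the full 2 * intermediate output is present

# Under plain colwise sharding each rank would hold only a contiguous slice of the
# 2 * intermediate columns, and `chunk(2)` on that local slice would cut the slice in
# half instead of separating the global gate / up halves. "colwise_rep" keeps the weight
# sharded but replicates the output, so the chunk stays correct; "rowwise_rep" on
# down_proj then accepts the replicated activation.
```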
9 changes: 9 additions & 0 deletions src/transformers/models/helium/configuration_helium.py
@@ -86,6 +86,15 @@ class HeliumConfig(PretrainedConfig):

model_type = "helium"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

def __init__(
self,
10 changes: 10 additions & 0 deletions src/transformers/models/mixtral/configuration_mixtral.py
@@ -109,6 +109,16 @@ class MixtralConfig(PretrainedConfig):

model_type = "mixtral"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.block_sparse_moe.gate": "colwise_rep", # we need to replicate here to correctly route experts
"layers.*.block_sparse_moe.experts.*.w1": "colwise",
"layers.*.block_sparse_moe.experts.*.w2": "rowwise",
"layers.*.block_sparse_moe.experts.*.w3": "colwise",
}

def __init__(
self,
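Mixtral's plan differs because of the MoE router: the gate's logits feed a top-k over all experts, so they must be complete on every rank, while the per-expert w1/w2/w3 weights shard like an ordinary MLP. A sketch of the routing step with hypothetical sizes:

```python
# Why the Mixtral router gate is "colwise_rep" (sketch, not the model implementation).
import torch
import torch.nn as nn

num_experts, top_k, hidden = 8, 2, 16
tokens = torch.randn(4, hidden)
gate = nn.Linear(hidden, num_experts, bias=False)

router_logits = gate(tokens)                        # (4, num_experts), must be complete
routing_weights = torch.softmax(router_logits, dim=-1)
topk_weights, topk_experts = routing_weights.topk(top_k, dim=-1)

# If the gate output were sharded column-wise, each rank would see logits for only a
# subset of the experts and the top-k above would pick from incomplete scores.
# Replicating the (tiny) router output avoids that; w1/w3 stay colwise and w2 rowwise.
```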
9 changes: 9 additions & 0 deletions src/transformers/models/olmo/configuration_olmo.py
@@ -106,6 +106,15 @@ class OlmoConfig(PretrainedConfig):

model_type = "olmo"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

def __init__(
self,
9 changes: 9 additions & 0 deletions src/transformers/models/olmo2/configuration_olmo2.py
@@ -89,6 +89,15 @@ class Olmo2Config(PretrainedConfig):

model_type = "olmo2"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

def __init__(
self,
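OLMo 2 applies RMSNorm to the full q and k projection outputs before the heads are split, so the normalization statistics must see every feature of the projection; that is why the attention projections use the replicated variants here. A sketch of the constraint, with made-up shapes and a hand-rolled RMSNorm:

```python
# Why OLMo 2's attention projections are "*_rep" (sketch with hypothetical sizes).
import torch
import torch.nn as nn

hidden, n_heads, head_dim = 16, 4, 4
x = torch.randn(2, hidden)
q_proj = nn.Linear(hidden, n_heads * head_dim, bias=False)

q = q_proj(x)
# RMSNorm over the full projection output: the mean of squares runs over all
# n_heads * head_dim features at once.
q = q * torch.rsqrt(q.pow(2).mean(dim=-1, keepdim=True) + 1e-6)

# With a colwise-sharded q_proj each rank would normalize over only its local slice of
# features, giving different statistics and a different result. "colwise_rep" keeps the
# weight sharded but replicates the projection output so the norm sees every feature;
# o_proj becomes "rowwise_rep" so it accepts that replicated input.
```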
9 changes: 9 additions & 0 deletions src/transformers/models/olmo2/modular_olmo2.py
@@ -100,6 +100,15 @@ class Olmo2Config(OlmoConfig):
"""

model_type = "olmo2"
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

def __init__(
self,
8 changes: 8 additions & 0 deletions src/transformers/models/phi/configuration_phi.py
@@ -138,6 +138,14 @@ class PhiConfig(PretrainedConfig):

model_type = "phi"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.dense": "rowwise",
"layers.*.mlp.fc1": "colwise",
"layers.*.mlp.fc2": "rowwise",
}

def __init__(
self,
6 changes: 6 additions & 0 deletions src/transformers/models/phi3/configuration_phi3.py
@@ -107,6 +107,12 @@ class Phi3Config(PretrainedConfig):

model_type = "phi3"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.qkv_proj": "colwise_rep", # we need to replicate here due to the slicing of qkv
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the slicing of qkv
"layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
"layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
}

def __init__(
self,
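Phi-3 hits the fused-projection problem twice: `qkv_proj` is a single linear whose output is sliced into q, k and v, and `gate_up_proj` is chunked as in GLM, so every entry in its plan uses a replicated variant. A sketch of the qkv slicing, with simplified hypothetical sizes (equal q/k/v widths):

```python
# Why Phi-3's fused qkv projection is "colwise_rep" (sketch, simplified shapes).
import torch
import torch.nn as nn

hidden, n_heads, head_dim = 16, 4, 4
x = torch.randn(2, hidden)
qkv_proj = nn.Linear(hidden, 3 * n_heads * head_dim, bias=False)

qkv = qkv_proj(x)
q, k, v = qkv.split(n_heads * head_dim, dim=-1)  # slicing assumes the full fused output

# A colwise shard of qkv_proj would give each rank only a contiguous slice of the fused
# q|k|v columns, so the split above would mix pieces of q, k and v. Replicating the
# output ("colwise_rep") keeps the slicing correct, and o_proj / down_proj use the
# matching "rowwise_rep" so they accept replicated inputs.
```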
2 changes: 2 additions & 0 deletions src/transformers/pytorch_utils.py
@@ -358,5 +358,7 @@ def translate_to_torch_parallel_style(style: str):
return RowwiseParallel()
elif style == "colwise_rep":
return ColwiseParallel(output_layouts=Replicate())
elif style == "rowwise_rep":
return RowwiseParallel(input_layouts=Replicate())
else:
raise ValueError(f"Unsupported parallel style value: {style}")
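With `rowwise_rep` added, the helper covers all four styles used by the plans above. A quick usage sketch follows; the import path is an assumption, and pattern matching of the wildcard keys against module names happens elsewhere in transformers.

```python
# Sketch: translating plan strings into torch.distributed parallel styles.
from transformers.pytorch_utils import translate_to_torch_parallel_style

phi3_plan = {
    "layers.*.self_attn.qkv_proj": "colwise_rep",  # sharded weight, replicated output
    "layers.*.self_attn.o_proj": "rowwise_rep",    # sharded weight, replicated input
    "layers.*.mlp.gate_up_proj": "colwise_rep",
    "layers.*.mlp.down_proj": "rowwise_rep",
}
torch_plan = {pattern: translate_to_torch_parallel_style(style) for pattern, style in phi3_plan.items()}

# An unknown style still fails loudly:
# translate_to_torch_parallel_style("diagonal")  # ValueError: Unsupported parallel style value: diagonal
```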