Correction to TP logic for Mamba Mixer 2 when Num Groups not divisible by TP Size #13660

Merged (2 commits) on Feb 22, 2025
vllm/model_executor/layers/mamba/mamba_mixer2.py (28 changes: 21 additions & 7 deletions)
@@ -133,7 +133,8 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int):
     if ngroups % tp_size == 0:
         return 0
 
-    return tp_size - ngroups % tp_size
+    # for n_groups == 1, this is exactly tp_size - n_groups
+    return tp_size - ngroups
 
 
 def mamba_v2_sharded_weight_loader(
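
As a quick illustration of what the updated helper above computes (plain Python, not part of the diff; the tp_size values are made-up examples): for the supported n_groups == 1 case it pads the group count up to the TP world size, so each head shard can carry its own replicated copy of the single group.

def extra_groups_for_head_shards(ngroups: int, tp_size: int):
    # same logic as the hunk above
    if ngroups % tp_size == 0:
        return 0
    # for n_groups == 1, this is exactly tp_size - n_groups
    return tp_size - ngroups

# a single group under 4-way TP needs 3 extra (replicated) groups,
# i.e. 1 + 3 == tp_size, one group per rank
assert extra_groups_for_head_shards(1, 4) == 3
# when the group count already divides the TP size, no padding is needed
assert extra_groups_for_head_shards(8, 4) == 0
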
@@ -153,7 +154,7 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
         boundary, loaded_boundary = 0, 0
 
         # - iterate over the shard specs
-        for full_dim, extra, ratio in shard_spec:
+        for full_dim, extra, duplicate_groups in shard_spec:
             # - full dim is the model dim (before TP).
             # - extra > 0, means there is expected overall increase
             #   of dimensions. This is so because of replication.
@@ -167,7 +168,12 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
             # - compute the rank into the loaded shard.
             # - if there is replication, different TP shards will
             #   take from the same rank.
-            rank = tp_rank // ratio
+            if duplicate_groups:
+                # NOTE: currently we only support duplication
+                # in the case where num_groups == 1
+                rank = 0
+            else:
+                rank = tp_rank
 
             # - leftmost boundary index into loaded weight.
             loaded_skip = rank * shard_size
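
To make the new duplicate_groups branch above concrete, here is a small standalone sketch (not the real loader; it ignores the extra/boundary bookkeeping, and the tensor sizes are invented): when the flag is set, every TP rank reads the same rank-0 slice of the checkpoint weight, so the single group is replicated across ranks; otherwise each rank reads its own slice.

import torch

def pick_slice(loaded_weight: torch.Tensor, full_dim: int, tp_size: int,
               tp_rank: int, duplicate_groups: bool) -> torch.Tensor:
    # mirrors the rank selection in loader(): replicate from rank 0
    # when groups are duplicated, otherwise shard by tp_rank
    shard_size = full_dim // tp_size
    rank = 0 if duplicate_groups else tp_rank
    return loaded_weight.narrow(0, rank * shard_size, shard_size)

w = torch.arange(8.0)  # pretend checkpoint weight with full_dim == 8
print(pick_slice(w, 8, tp_size=4, tp_rank=1, duplicate_groups=False))  # tensor([2., 3.])
print(pick_slice(w, 8, tp_size=4, tp_rank=1, duplicate_groups=True))   # tensor([0., 1.])
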
@@ -233,12 +239,21 @@ def __init__(self,
         # - HOWEVER IF, world_size DOES NOT divide groups, then we need
         #   to allocate extra space in the shard, such that groups
         #   may be replicated to follow the head shard.
+        # - NOTE: currently for the world size DOES NOT divide groups
+        #   case, we only support the case when n_groups == 1
         self.tp_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
 
+        assert num_heads % self.tp_size == 0, \
+            "Tensor parallel world size must divide num heads."
 
 
+        assert (n_groups % self.tp_size) == 0 or n_groups == 1, \
+            (
+                "If tensor parallel world size does not divide num_heads, "
+                "then num_groups must equal 1."
+            )
 
         self.ssm_state_size = ssm_state_size
         self.activation = activation
 
@@ -284,11 +299,10 @@ def __init__(self,
             self.n_groups * self.ssm_state_size,  # expected model size
             (self.n_groups - n_groups) *
             self.ssm_state_size,  # extra dims assigned
-            self.num_heads //
-            n_groups,  # ratio for mapping back to original group
+            n_groups == 1,  # if there was only one group
         )
-        intermediate_settings = (intermediate_size, 0, 1)
-        head_setings = (self.num_heads, 0, 1)
+        intermediate_settings = (intermediate_size, 0, False)
+        head_setings = (self.num_heads, 0, False)
 
         # - the weight already has a "weight_loader" attribute
         #   which set_weight_attrs will raise if we do not
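
Putting the shard-spec tuples from the last hunk into numbers may help. This is a rough sketch with invented values (n_groups = 1, ssm_state_size = 128, num_heads = 128, head_dim = 64, tp_size = 4); the variable name for the group tuple is a placeholder, since its assignment line sits outside the visible hunk.

n_groups, ssm_state_size = 1, 128
num_heads, head_dim, tp_size = 128, 64, 4
intermediate_size = num_heads * head_dim

extra_groups = tp_size - n_groups            # 3, from the first hunk
padded_n_groups = n_groups + extra_groups    # 4, one group per rank

group_settings = (
    padded_n_groups * ssm_state_size,        # expected model size: 512
    extra_groups * ssm_state_size,           # extra dims assigned: 384
    n_groups == 1,                           # duplicate_groups flag: True
)
intermediate_settings = (intermediate_size, 0, False)  # sharded normally
head_setings = (num_heads, 0, False)                   # sharded normally
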