Commit

Merge branch 'zijiey/moe_router_2nd' into 'main'
MoE Refactoring - Switch to mask-based routing for MoE

Closes #267

See merge request ADLR/megatron-lm!1915
ko3n1g committed Oct 30, 2024
2 parents 2e047cf + ac0474d commit e084ab0
Showing 7 changed files with 291 additions and 480 deletions.
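For context on the change itself: after this refactor the router returns a dense per-token expert mask (`routing_map`, shape `[num_tokens, num_experts]`) alongside the routing probabilities, instead of a `[num_tokens, topk]` tensor of expert indices. A minimal sketch of the two representations, assuming a plain top-k softmax router (toy sizes, illustrative only — this is not Megatron code):

```python
import torch

# Toy sizes, illustrative only.
num_tokens, num_experts, topk = 4, 8, 2
logits = torch.randn(num_tokens, num_experts)
scores = torch.softmax(logits, dim=-1)

# Index-based routing (old interface): top-k expert indices per token.
topk_probs, topk_indices = torch.topk(scores, k=topk, dim=-1)  # both [num_tokens, topk]

# Mask-based routing (this commit): a boolean mask over all experts per token.
routing_map = (
    torch.zeros(num_tokens, num_experts, dtype=torch.int)
    .scatter_(1, topk_indices, 1)
    .bool()
)  # routing_map[t, e] is True iff token t is routed to expert e

# Per-expert token counts reduce to a column sum over the mask,
# which is what the updated preprocess() below does.
num_local_tokens_per_expert = routing_map.sum(dim=0).long()  # [num_experts]
```

The dispatcher changes in this diff consume `routing_map` directly for counting, permutation, and unpermutation.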
159 changes: 85 additions & 74 deletions megatron/core/transformer/moe/legacy_a2a_token_dispatcher.py
@@ -3,17 +3,26 @@
from typing import List, Optional, Tuple

import torch
import torch.distributed

from megatron.core import parallel_state, tensor_parallel
from megatron.core.tensor_parallel.mappings import _gather_along_first_dim_expert_parallel
from megatron.core.transformer.moe.moe_utils import permute, unpermute
from megatron.core.transformer.moe.moe_utils import (
get_capacity,
permute,
sort_chunks_by_idxs,
unpermute,
)
from megatron.core.transformer.moe.token_dispatcher import MoETokenDispatcher
from megatron.core.transformer.transformer_config import TransformerConfig


class MoEAlltoAllSEQTokenDispatcher(MoETokenDispatcher):
"""
The legacy implementation of the AlltoAll-based token dispatcher, which handles token dispatching on the sequence level instead of token level. The core of this implementation lies each device dispatching on the entire sequence, with the hidden state being partitioned.
The legacy implementation of the AlltoAll-based token dispatcher, which handles token
dispatching on the sequence level instead of token level. The core of this implementation
lies in each device dispatching on the entire sequence, with the hidden state being partitioned.
Note: This class is a replica of the MoEAlltoAllTokenDispatcher from version 0.8.
"""

@@ -34,12 +43,6 @@ def __init__(
self.num_local_experts = num_local_experts
self.num_experts = config.num_moe_experts
assert self.num_local_experts > 0, "Expected at least one expert"
if self.num_local_experts > 1:
self.expert_ids_per_ep_rank = torch.tensor(
[i % self.num_local_experts for i in range(self.num_experts)],
dtype=torch.int32,
device=torch.cuda.current_device(),
)
self.local_expert_indices = local_expert_indices
assert (
len(self.local_expert_indices) == self.num_local_experts
@@ -48,13 +51,23 @@
assert (
self.local_expert_indices[i] == self.local_expert_indices[i + 1] - 1
), "local_expert_indices must be continous"
self.router_topk = config.moe_router_topk
self.add_bias = config.add_bias_linear
self.ep_size = config.expert_model_parallel_size
self.tp_size = config.tensor_model_parallel_size
self.probs = None
self.input_splits = None
self.output_splits = None
self.num_global_tokens_per_local_expert = None
# [tp_size * ep_size, num_local_experts]. Represents the number of tokens sent
# to each local expert by all ranks.
self.num_global_tokens_per_local_expert_cpu = None
input_chunk_idxs = torch.arange(self.num_experts)
# [num_local_experts, ep_size]. Sort the input chunks by local experts.
self.sort_input_by_local_experts = (
input_chunk_idxs.reshape(-1, self.num_local_experts).T.ravel().tolist()
)
# [ep_size, num_local_experts]. Restore the output chunks by local experts.
self.restore_output_by_local_experts = (
input_chunk_idxs.reshape(self.num_local_experts, -1).T.ravel().tolist()
)

# Token drop and padding.
# We need to keep track of the token num if we drop tokens without padding them.
@@ -65,36 +78,48 @@ def __init__(
assert self.config.moe_expert_capacity_factor is not None
self.capacity = None

# A cuda stream synchronization is needed in self.token_permutation() in some cases,
# because there are several non-blocking DtoH data transfers called in self.preprocess().
# The synchronization happens at different points based on MoE settings as late as possible.
# Valid sync points are "before_permutation_1", "before_ep_alltoall", "before_finish", and "no_sync".
# A cuda stream synchronization is needed in self.token_permutation()
# in some cases, because there are several non-blocking DtoH data
# transfers called in self.preprocess(). The synchronization happens
# at different points based on MoE settings as late as possible.
# Valid sync points are "before_permutation_1", "before_ep_alltoall",
# "before_finish", and "no_sync".
self.cuda_sync_point = "no_sync"

def preprocess(self, indices: torch.Tensor) -> torch.Tensor:
def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor:
"""
Preprocess token indices for AlltoAll communication and token permutation. This method computes the number of tokens assigned to each expert based on the input indices.
It also initializes the necessary data structures for AlltoAll communication, such as input
and output splits, and the mapping between global tokens and local experts.
Preprocess routing map for AlltoAll communication and token permutation.
This method computes the number of tokens assigned to each expert based on
the routing map. It also initializes the necessary data structures for
AlltoAll communication, such as input and output splits, and the mapping
between global tokens and local experts.
Args:
indices (torch.Tensor): Tensor of indices mapping tokens to experts.
routing_map (torch.Tensor): The mapping of tokens to experts, with shape
[num_tokens, num_experts].
Returns:
torch.Tensor: Tensor containing the number of tokens assigned to local expert.
"""
num_local_tokens_per_expert = torch.histc(
indices, bins=self.num_experts, min=0, max=self.num_experts
)
num_local_tokens_per_expert = routing_map.sum(dim=0).long()
# num_local_tokens_per_expert: [num_experts]

ep_size = self.config.expert_model_parallel_size
if self.drop_and_pad:
# probs: [num_experts, capacity]
self.capacity = self.probs.size(1)
# Drop and pad the input to capacity.
num_tokens = routing_map.size(0) * self.config.moe_router_topk
self.capacity = get_capacity(
num_tokens=num_tokens,
num_experts=self.num_experts,
capacity_factor=self.config.moe_expert_capacity_factor,
)
self.num_out_tokens = self.capacity * self.num_experts
num_tokens_per_local_expert = torch.full(
(self.num_local_experts,), self.capacity * self.ep_size, dtype=torch.long
)
self.num_global_tokens_per_local_expert_cpu = torch.full(
(self.num_experts * self.tp_size,), self.capacity, dtype=torch.long
)
return num_tokens_per_local_expert
elif self.config.moe_expert_capacity_factor is not None:
# Token drop but no pad. A synchronization is needed before the first
@@ -103,14 +128,17 @@ def preprocess(self, indices: torch.Tensor) -> torch.Tensor:
torch.device("cpu"), non_blocking=True
)
self.cuda_sync_point = "before_permutation_1"
elif ep_size > 1:
# Token dropless and enable ep. A synchronization is needed before expert parallel
# AlltoAll communication to get the `input_splits` and `output_splits` CPU values.
self.cuda_sync_point = "before_ep_alltoall"
else:
# Token dropless and no ep. A synchronization is needed before the token_permutation()
# function returns to get the `tokens_per_expert` CPU value.
self.cuda_sync_point = "before_finish"
# Dropless
self.num_out_tokens = routing_map.size(0) * self.config.moe_router_topk
if self.ep_size > 1 or self.num_local_experts > 1:
# Token dropless and enable ep. A synchronization is needed before expert parallel
# AlltoAll communication to get the `input_splits` and `output_splits` CPU values.
self.cuda_sync_point = "before_ep_alltoall"
else:
# Token dropless and no ep. A synchronization is needed to get the
# `tokens_per_expert` CPU value.
self.cuda_sync_point = "before_finish"

if ep_size > 1:
# ===================================================
@@ -150,25 +178,26 @@ def preprocess(self, indices: torch.Tensor) -> torch.Tensor:
)

if self.num_local_experts > 1:
# No further synchronization is needed because torch.repeat_interleave() calls stream
# synchronization internally when the `output_size` parameter is not provided.
self.cuda_sync_point = "no_sync"
self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
self.expert_ids_per_ep_rank, self.num_global_tokens_per_local_expert.ravel()
self.num_global_tokens_per_local_expert_cpu = (
self.num_global_tokens_per_local_expert.view(-1, self.num_local_experts).to(
torch.device("cpu"), non_blocking=True
)
)

return num_tokens_per_local_expert

def token_permutation(
self, hidden_states: torch.Tensor, probs: torch.Tensor, indices: torch.Tensor
self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Dispatch tokens to local experts using AlltoAll communication.
Args:
hidden_states (torch.Tensor): Input token embeddings.
probs (torch.Tensor): Probs of tokens assigned to experts.
indices (torch.Tensor): Indices of tokens assigned to experts.
Shape: [num_tokens, num_experts].
routing_map (torch.Tensor): Mapping of tokens assigned to experts.
Shape: [num_tokens, num_experts].
Returns:
Tuple[torch.Tensor, torch.Tensor]:
Expand All @@ -178,10 +207,11 @@ def token_permutation(
# Preprocess: Get the metadata for communication, permutation and computation operations.
self.hidden_shape = hidden_states.shape
self.probs = probs
self.routing_map = routing_map
assert probs.dim() == 2, "Expected 2D tensor for probs"
assert indices.dim() == 2, "Expected 2D tensor for indices"
assert routing_map.dim() == 2, "Expected 2D tensor for routing map"
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
tokens_per_expert = self.preprocess(indices)
tokens_per_expert = self.preprocess(routing_map)

# Perform tensor parallel AlltoAll communication
# hidden_states: [S*B/TP, H] -> [S*B, H/TP]
@@ -193,10 +223,7 @@ def token_permutation(
if self.cuda_sync_point == "before_permutation_1":
torch.cuda.current_stream().synchronize()
permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute(
hidden_states,
indices,
num_out_tokens=self.num_out_tokens,
padded_mode=self.drop_and_pad,
hidden_states, routing_map, num_out_tokens=self.num_out_tokens
)

# Perform expert parallel AlltoAll communication
@@ -209,21 +236,13 @@ def token_permutation(
self.input_splits,
)

# Permutation 2: Sort alltoall output by local experts when num_local_experts > 1.
# Permutation 2: Sort tokens by local expert.
if self.num_local_experts > 1:
if not self.drop_and_pad:
global_input_tokens, self.reversed_global_input_permutation_mapping = permute(
global_input_tokens, self.global_input_tokens_local_experts_indices
)
else:
global_input_tokens = global_input_tokens.reshape(
self.ep_size, self.num_local_experts, self.capacity, -1
)
global_input_tokens = (
global_input_tokens.transpose(0, 1)
.reshape(self.num_local_experts * self.ep_size * self.capacity, -1)
.contiguous()
)
global_input_tokens = sort_chunks_by_idxs(
global_input_tokens,
self.num_global_tokens_per_local_expert_cpu.ravel(),
self.sort_input_by_local_experts,
)

# Perform tensor parallel AllGather on the hidden dimension to obtain the input tokens.
# global_input_tokens: [SEQL, H/TP] -> [SEQL, H]
@@ -260,21 +279,13 @@ def token_unpermutation(
hidden_states
)

# Unpermutation 2: expert output to AlltoAll input
# Unpermutation 2: Unsort tokens by local expert.
if self.num_local_experts > 1:
if not self.drop_and_pad:
hidden_states = unpermute(
hidden_states, self.reversed_global_input_permutation_mapping
)
else:
hidden_states = hidden_states.reshape(
self.num_local_experts, self.ep_size, self.capacity, -1
)
hidden_states = (
hidden_states.transpose(0, 1)
.reshape(self.ep_size * self.num_local_experts * self.capacity, -1)
.contiguous()
)
hidden_states = sort_chunks_by_idxs(
hidden_states,
self.num_global_tokens_per_local_expert_cpu.T.ravel(),
self.restore_output_by_local_experts,
)

# Perform expert parallel AlltoAll communication
# hidden_states: [SEQL, H] -> [SEQL, H/TP]
@@ -290,8 +301,8 @@ def token_unpermutation(
permutated_local_input_tokens,
self.reversed_local_input_permutation_mapping,
probs=self.probs,
padded_mode=self.drop_and_pad,
restore_shape=self.hidden_shape_before_permute,
routing_map=self.routing_map,
)

# Perform tensor parallel AlltoAll communication
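The permutation-2/unpermutation-2 changes above replace both the `torch.repeat_interleave` index path and the drop-and-pad reshape/transpose path with a single chunk re-sort driven by CPU metadata (`num_global_tokens_per_local_expert_cpu` plus the chunk orders precomputed in `__init__`). The helper below is a hypothetical stand-in for `sort_chunks_by_idxs`, written only to illustrate the semantics implied by the call sites in this diff; names and sizes are illustrative:

```python
import torch

def sort_chunks_by_idxs_sketch(tensor: torch.Tensor,
                               split_sizes: torch.Tensor,
                               sorted_idxs: list) -> torch.Tensor:
    """Hypothetical stand-in: split `tensor` along dim 0 into chunks of
    `split_sizes`, then concatenate the chunks in the order `sorted_idxs`."""
    chunks = torch.split(tensor, split_sizes.tolist(), dim=0)
    return torch.cat([chunks[i] for i in sorted_idxs], dim=0)

# The chunk orders mirror the __init__ code above (assuming tp_size == 1):
# one groups chunks of the same local expert across EP ranks, the other reverses it.
num_experts, num_local_experts = 8, 2  # illustrative sizes only
input_chunk_idxs = torch.arange(num_experts)
sort_input_by_local_experts = (
    input_chunk_idxs.reshape(-1, num_local_experts).T.ravel().tolist()
)  # [0, 2, 4, 6, 1, 3, 5, 7]
restore_output_by_local_experts = (
    input_chunk_idxs.reshape(num_local_experts, -1).T.ravel().tolist()
)  # [0, 4, 1, 5, 2, 6, 3, 7]
```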
4 changes: 2 additions & 2 deletions megatron/core/transformer/moe/moe_layer.py
@@ -144,9 +144,9 @@ def forward(self, hidden_states: torch.Tensor):

# process MoE
def custom_forward(hidden_states):
probs, indices = self.router(hidden_states)
probs, routing_map = self.router(hidden_states)
(dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(
hidden_states, probs, indices
hidden_states, probs, routing_map
)
expert_output, mlp_bias = self.experts(dispatched_input, tokens_per_expert)
output, mlp_bias = self.token_dispatcher.token_unpermutation(expert_output, mlp_bias)
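At the layer level, the only visible interface change is that the router's second output is now the routing map, forwarded to the dispatcher unchanged. Below is a hedged sketch of a mask-producing top-k router matching the shapes the dispatcher docstrings above expect (`probs` and `routing_map` both `[num_tokens, num_experts]`); the actual router changes live in files not shown in this excerpt:

```python
import torch

def topk_mask_router_sketch(logits: torch.Tensor, topk: int):
    """Hypothetical sketch, not Megatron's router: returns dense probs and a
    boolean routing map, both shaped [num_tokens, num_experts]."""
    scores = torch.softmax(logits, dim=-1)
    topk_vals, topk_idx = torch.topk(scores, k=topk, dim=-1)
    routing_map = torch.zeros_like(scores, dtype=torch.int).scatter_(1, topk_idx, 1).bool()
    probs = torch.zeros_like(scores).scatter_(1, topk_idx, topk_vals)  # zeros for unrouted experts
    return probs, routing_map

# Wiring mirrors custom_forward() in the diff above:
#   probs, routing_map = self.router(hidden_states)
#   dispatched_input, tokens_per_expert = self.token_dispatcher.token_permutation(
#       hidden_states, probs, routing_map)
#   expert_output, mlp_bias = self.experts(dispatched_input, tokens_per_expert)
#   output, mlp_bias = self.token_dispatcher.token_unpermutation(expert_output, mlp_bias)
```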