diff --git a/sharktank/sharktank/layers/ffn_moe_block.py b/sharktank/sharktank/layers/ffn_moe_block.py
index 471387853..626cad90d 100644
--- a/sharktank/sharktank/layers/ffn_moe_block.py
+++ b/sharktank/sharktank/layers/ffn_moe_block.py
@@ -31,32 +31,37 @@ def __init__(
         self.ffn_up = theta.tensor("ffn_up_exps", "weight")
         self.ffn_down = theta.tensor("ffn_down_exps", "weight")
 
-    # def pre_matmul_gather(self, inputs, weights, experts):
-    #    inputs = inputs[:,:]
-    #    weights = weights[experts.reshape(-1), :, :]
-    #    matmul = torch.einsum("mk,mnk->mn", inputs, weights)
-    #    return matmul
     def pre_matmul_gather(self, inputs, weights, experts):
-        matmul = torch.einsum("mk,bnk->bmn", inputs, weights)
-
-        # Post mix the experts
-        oh = (
-            torch.nn.functional.one_hot(experts.reshape(-1), num_classes=8)
-            .transpose(0, 1)
-            .to(torch.float32)
-        )
-        output = torch.einsum("bm,bmn->mn", oh, matmul)
-        return output
+        inputs = inputs[:, :]
+        weights = weights[experts, :, :]
+        matmul = torch.einsum("mk,menk->men", inputs, weights)
+        return matmul
+
+    def bigger_mmg(self, inputs, weights, experts):
+        inputs = inputs[:, :]
+        weights = weights[experts, :, :]
+        matmul = torch.einsum("mek,menk->men", inputs, weights)
+        return matmul
+
+    # def pre_matmul_gather(self, inputs, weights, experts):
+    #     matmul = torch.einsum("mk,bnk->bmn", inputs, weights)
+    #
+    #     # Post mix the experts
+    #     oh = torch.nn.functional.one_hot(experts.reshape(-1), num_classes=8).transpose(0, 1).to(torch.float32)
+    #     output = torch.einsum("bm,bmn->mn", oh, matmul)
+    #     return output
 
     def forward(
         self,
         h: torch.Tensor,
         experts: torch.Tensor,
+        expert_gate: torch.Tensor,
     ):
         ffn_gate = F.silu(self.pre_matmul_gather(h, self.ffn_gate.as_torch(), experts))
         ffn_up = self.pre_matmul_gather(h, self.ffn_up, experts)
-        ffn_down = self.pre_matmul_gather(ffn_gate * ffn_up, self.ffn_down, experts)
-        return ffn_down
+        ffn_down = self.bigger_mmg(ffn_gate * ffn_up, self.ffn_down, experts)
+        ffn_down = torch.einsum("me,men->men", expert_gate, ffn_down)
+        return torch.sum(ffn_down, dim=1)
 
 
 class FFNMOE(ThetaLayer):
diff --git a/sharktank/sharktank/layers/mixture_of_experts_block.py b/sharktank/sharktank/layers/mixture_of_experts_block.py
index 49eaead95..317aa92b5 100644
--- a/sharktank/sharktank/layers/mixture_of_experts_block.py
+++ b/sharktank/sharktank/layers/mixture_of_experts_block.py
@@ -167,14 +167,14 @@ def forward(
         router_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
 
         # Select top k experts from router weights
-        router_weights, top_k_experts = torch.topk(
+        expert_gate, top_k_experts = torch.topk(
            router_weights, self.expert_used_count, dim=-1
         )
 
-        router_weights /= router_weights.sum(dim=-1, keepdim=True)
-        router_weights = router_weights.to(ffn_input.dtype)
+        # router_weights /= router_weights.sum(dim=-1, keepdim=True)
+        # router_weights = router_weights.to(ffn_input.dtype)
 
-        moe_output = self.mix(ffn_input, top_k_experts)
+        moe_output = self.mix(ffn_input, top_k_experts, expert_gate)
         moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim)
 
         moe_output = self.layer_output_norm(moe_output)
diff --git a/sharktank/sharktank/models/llama/testing.py b/sharktank/sharktank/models/llama/testing.py
index b63fd5d07..43071e8fb 100644
--- a/sharktank/sharktank/models/llama/testing.py
+++ b/sharktank/sharktank/models/llama/testing.py
@@ -60,7 +60,7 @@ def make_moe_block_theta(feature_dim=1024, ffn_dim=6144, num_experts=8) -> Theta
     return Theta(
         {
             "blk.0.ffn_gate_inp.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((feature_dim, ffn_dim))
+                data=make_rand_torch((8, ffn_dim))
             ),
             "blk.0.ffn_norm.weight": DefaultPrimitiveTensor(
                 data=make_rand_torch((ffn_dim))
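
Note (not part of the patch): the sketch below is a minimal standalone illustration of the gather-and-mix pattern this diff introduces in `pre_matmul_gather`, `bigger_mmg`, and the `expert_gate` weighting. All shapes and names (`num_tokens`, `d_model`, `w_up`, etc.) are assumed toy values for demonstration, not taken from the repository.

```python
import torch

# Assumed toy sizes: 4 tokens, 8 experts, top-2 routing, model dim 16, FFN dim 32.
num_tokens, num_experts, top_k, d_model, d_ffn = 4, 8, 2, 16, 32

h = torch.randn(num_tokens, d_model)                      # token activations: (m, k)
w_up = torch.randn(num_experts, d_ffn, d_model)           # per-expert weights: (E, n, k)
experts = torch.randint(0, num_experts, (num_tokens, top_k))          # top-k expert ids: (m, e)
expert_gate = torch.softmax(torch.randn(num_tokens, top_k), dim=-1)   # routing weights: (m, e)

# Gather each token's selected expert weights, then matmul per (token, expert) pair,
# mirroring pre_matmul_gather's "mk,menk->men" contraction.
gathered = w_up[experts, :, :]                             # (m, e, n, k)
ffn_up = torch.einsum("mk,menk->men", h, gathered)         # (m, e, n)

# Scale each expert's output by its gate and sum over the selected experts,
# as the new forward() does with "me,men->men" followed by a sum over dim 1.
mixed = torch.einsum("me,men->men", expert_gate, ffn_up).sum(dim=1)   # (m, n)
print(mixed.shape)  # torch.Size([4, 32])
```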