diff --git a/sharktank/sharktank/layers/ffn_moe_block.py b/sharktank/sharktank/layers/ffn_moe_block.py
index 471387853..626cad90d 100644
--- a/sharktank/sharktank/layers/ffn_moe_block.py
+++ b/sharktank/sharktank/layers/ffn_moe_block.py
@@ -31,32 +31,37 @@ def __init__(
         self.ffn_up = theta.tensor("ffn_up_exps", "weight")
         self.ffn_down = theta.tensor("ffn_down_exps", "weight")
 
-    #    def pre_matmul_gather(self, inputs, weights, experts):
-    #        inputs = inputs[:,:]
-    #        weights = weights[experts.reshape(-1), :, :]
-    #        matmul = torch.einsum("mk,mnk->mn", inputs, weights)
-    #        return matmul
     def pre_matmul_gather(self, inputs, weights, experts):
-        matmul = torch.einsum("mk,bnk->bmn", inputs, weights)
-
-        # Post mix the experts
-        oh = (
-            torch.nn.functional.one_hot(experts.reshape(-1), num_classes=8)
-            .transpose(0, 1)
-            .to(torch.float32)
-        )
-        output = torch.einsum("bm,bmn->mn", oh, matmul)
-        return output
+        inputs = inputs[:, :]
+        weights = weights[experts, :, :]
+        matmul = torch.einsum("mk,menk->men", inputs, weights)
+        return matmul
+
+    def bigger_mmg(self, inputs, weights, experts):
+        inputs = inputs[:, :]
+        weights = weights[experts, :, :]
+        matmul = torch.einsum("mek,menk->men", inputs, weights)
+        return matmul
+
+    # def pre_matmul_gather(self, inputs, weights, experts):
+    #    matmul = torch.einsum("mk,bnk->bmn", inputs, weights)
+    #
+    #        # Post mix the experts
+    #        oh = torch.nn.functional.one_hot(experts.reshape(-1), num_classes=8).transpose(0, 1).to(torch.float32)
+    #        output = torch.einsum("bm,bmn->mn", oh, matmul)
+    #        return output
 
     def forward(
         self,
         h: torch.Tensor,
         experts: torch.Tensor,
+        expert_gate: torch.Tensor,
     ):
         ffn_gate = F.silu(self.pre_matmul_gather(h, self.ffn_gate.as_torch(), experts))
         ffn_up = self.pre_matmul_gather(h, self.ffn_up, experts)
-        ffn_down = self.pre_matmul_gather(ffn_gate * ffn_up, self.ffn_down, experts)
-        return ffn_down
+        ffn_down = self.bigger_mmg(ffn_gate * ffn_up, self.ffn_down, experts)
+        ffn_down = torch.einsum("me,men->men", expert_gate, ffn_down)
+        return torch.sum(ffn_down, dim=1)
 
 
 class FFNMOE(ThetaLayer):
diff --git a/sharktank/sharktank/layers/mixture_of_experts_block.py b/sharktank/sharktank/layers/mixture_of_experts_block.py
index 49eaead95..317aa92b5 100644
--- a/sharktank/sharktank/layers/mixture_of_experts_block.py
+++ b/sharktank/sharktank/layers/mixture_of_experts_block.py
@@ -167,14 +167,14 @@ def forward(
         router_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
 
         # Select top k experts from router weights
-        router_weights, top_k_experts = torch.topk(
+        expert_gate, top_k_experts = torch.topk(
             router_weights, self.expert_used_count, dim=-1
         )
 
-        router_weights /= router_weights.sum(dim=-1, keepdim=True)
-        router_weights = router_weights.to(ffn_input.dtype)
+        # router_weights /= router_weights.sum(dim=-1, keepdim=True)
+        # router_weights = router_weights.to(ffn_input.dtype)
 
-        moe_output = self.mix(ffn_input, top_k_experts)
+        moe_output = self.mix(ffn_input, top_k_experts, expert_gate)
         moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim)
 
         moe_output = self.layer_output_norm(moe_output)
diff --git a/sharktank/sharktank/models/llama/testing.py b/sharktank/sharktank/models/llama/testing.py
index b63fd5d07..43071e8fb 100644
--- a/sharktank/sharktank/models/llama/testing.py
+++ b/sharktank/sharktank/models/llama/testing.py
@@ -60,7 +60,7 @@ def make_moe_block_theta(feature_dim=1024, ffn_dim=6144, num_experts=8) -> Theta
     return Theta(
         {
             "blk.0.ffn_gate_inp.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((feature_dim, ffn_dim))
+                data=make_rand_torch((8, ffn_dim))
             ),
             "blk.0.ffn_norm.weight": DefaultPrimitiveTensor(
                 data=make_rand_torch((ffn_dim))