Skip to content

Commit

Permalink
test input consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
sambhavnoobcoder committed Jan 22, 2025
1 parent 99aed7b commit 691a67d
Showing 1 changed file with 35 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,38 @@ def test_router_training_mode(self):
self.assertFalse(
torch.allclose(output1, output2),
"Outputs should differ in training mode due to jitter noise"
)
)

def test_expert_inputs_consistency(self):
    """Test that expert inputs are consistent and not affected by jitter.

    Feeds copies of the same hidden states through a
    ``SwitchTransformersSparseMLP`` in training mode several times, records
    what the first expert receives through a forward hook, and checks that
    every recorded input matches the one from the first run.
    """
    model = SwitchTransformersSparseMLP(self.config)
    model.train()  # Set to training mode to enable jitter

    # Create input
    hidden_states = torch.randn(2, 4, 32)

    # Store expert inputs during forward pass
    expert_inputs = []

    def hook_fn(module, inputs, output):
        # Clone so the recorded tensor cannot be changed by later in-place ops.
        expert_inputs.append(inputs[0].clone())

    # Register forward hook on first expert
    handle = model.experts.expert_0.register_forward_hook(hook_fn)
    try:
        # Multiple forward passes, each on a fresh copy of the same input
        for _ in range(3):
            model(hidden_states.clone())
    finally:
        # Always detach the hook, even when a forward pass raises
        handle.remove()

    # With fewer than two captures the comparison loop below would run zero
    # times and the test would pass without checking anything. Surface that
    # as a skip rather than a silent pass.
    if len(expert_inputs) < 2:
        self.skipTest(
            f"Expert hook fired {len(expert_inputs)} time(s); "
            "need at least 2 captures to compare expert inputs"
        )

    # Verify all expert inputs are identical
    reference = expert_inputs[0]
    for i, captured in enumerate(expert_inputs[1:], start=1):
        # Compare shapes first: allclose on mismatched shapes either raises
        # or broadcasts, neither of which gives a readable test failure.
        self.assertEqual(
            reference.shape,
            captured.shape,
            f"Expert input shapes differ between run 0 and run {i}"
        )
        self.assertTrue(
            torch.allclose(reference, captured, atol=1e-5),
            f"Expert inputs differ between run 0 and run {i}"
        )

# Run the test suite only when this file is executed as a script,
# not when it is imported by a test runner.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 691a67d

Please sign in to comment.