Skip to content

Commit

Permalink
final fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
sambhavnoobcoder authored Jan 27, 2025
1 parent a49fb0b commit d37e04f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/models/mask2former/test_initialization_mask2former.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,23 @@ class TestMask2FormerInitialization(unittest.TestCase):
def setUpClass(cls):
    # One-time fixture for the whole test class: build a default-configured
    # Mask2Former model that every test method inspects via self.model.
    # NOTE(review): presumably decorated with @classmethod just above this
    # view — confirm in the full file.
    cls.config = Mask2FormerConfig()  # default config; init behavior under test depends on it
    cls.model = Mask2FormerModel(cls.config)

def test_embedding_initialization(self):
"""Test that embeddings are initialized with std=1.0 (PyTorch default)"""
for name, module in self.model.named_modules():
if isinstance(module, torch.nn.Embedding):
# Calculate empirical standard deviation
std = torch.std(module.weight.data).item()
self.assertAlmostEqual(std, 1.0, places=1)

def test_mlp_bias_initialization(self):
    """Verify that 1-D parameters of decoder layers are not left all-zero.

    Walks every Mask2FormerMaskedAttentionDecoderLayer in the model and
    checks each 1-D parameter tensor (taken here as a bias term — note this
    predicate also matches other 1-D params such as norm weights) contains
    at least one non-zero element.
    """
    for _, submodule in self.model.named_modules():
        if not isinstance(submodule, Mask2FormerMaskedAttentionDecoderLayer):
            continue
        for weight in submodule.parameters():
            if weight.dim() != 1:
                continue
            self.assertFalse(torch.all(weight.data == 0))

def test_multiscale_deformable_attention(self):
"""Test that multiscale deformable attention is properly initialized"""
for name, module in self.model.named_modules():
Expand Down

0 comments on commit d37e04f

Please sign in to comment.