Skip to content

Commit

Permalink
[ADD] Unit test for FlattenEmbedding
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Nov 4, 2024
1 parent 7fa17ea commit 87a3c1d
Showing 1 changed file with 36 additions and 4 deletions.
40 changes: 36 additions & 4 deletions test/test_combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

from pytest import mark
from torch import allclose, device, manual_seed, rand, zeros
from torch.nn import Linear
from torch.nn import Embedding, Linear

from sirfshampoo.combiner import LinearWeightBias, PerParameter
from sirfshampoo.combiner import FlattenEmbedding, LinearWeightBias, PerParameter


@mark.parametrize("dev", DEVICES, ids=DEVICE_IDS)
Expand Down Expand Up @@ -66,8 +66,9 @@ def test_LinearWeightBias_identify(dev: device):
Args:
dev: Device to run the test on.
"""
model = Linear(2, 3).to(dev)
assert LinearWeightBias().identify(model) == [[model.weight, model.bias]]
for linear_cls in LinearWeightBias.LINEAR_CLS:
model = linear_cls(2, 3).to(dev)
assert LinearWeightBias().identify(model) == [[model.weight, model.bias]]


@mark.parametrize("dev", DEVICES, ids=DEVICE_IDS)
Expand Down Expand Up @@ -105,3 +106,34 @@ def test_LinearWeightBias_group_and_ungroup(dev: device):
assert allclose(grouped, tensor)
(W_ungrouped, b_ungrouped) = LinearWeightBias().ungroup(grouped, shapes)
assert allclose(W_ungrouped, W) and allclose(b_ungrouped, b)


@mark.parametrize("dev", DEVICES, ids=DEVICE_IDS)
def test_FlattenEmbedding_identify(dev: device):
"""Test parameter identification of `FlattenEmbedding` class.
Args:
dev: Device to run the test on.
"""
for embedding_cls in FlattenEmbedding.EMBEDDING_CLS:
model = embedding_cls(2, 3).to(dev)
assert FlattenEmbedding().identify(model) == [[model.weight]]


@mark.parametrize("dev", DEVICES, ids=DEVICE_IDS)
def test_FlattenEmbedding_group_and_ungroup(dev: device):
"""Test tensor (un-)grouping of `FlattenEmbedding` class.
Args:
dev: Device to run the test on.
"""
manual_seed(0)

tensor = rand(5, 10, device=dev)
tensor_flat = tensor.flatten()
shapes = [tensor.shape]

grouped = FlattenEmbedding().group([tensor])
assert allclose(grouped, tensor_flat)
(ungrouped,) = FlattenEmbedding().ungroup(grouped, shapes)
assert allclose(ungrouped, tensor)

0 comments on commit 87a3c1d

Please sign in to comment.