From eb452b06c8170dfcf7b5bb8a272321b894e07d47 Mon Sep 17 00:00:00 2001
From: valmat07
Date: Mon, 26 Dec 2022 18:08:31 +0300
Subject: [PATCH] Add implementation for PyTorch weight normalization

---
 python/tvm/relay/frontend/pytorch.py          | 15 ++++++++++++
 tests/python/frontend/pytorch/test_forward.py | 24 +++++++++++++++++++
 2 files changed, 39 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 30f14b490b1b..87264da139a1 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -3500,6 +3500,20 @@ def multinomial(self, inputs, input_types):
         _, indices = _expr.TupleWrapper(output, 2)
         return indices
 
+    def weight_norm(self, inputs, input_types):
+        weight_v, weight_g = inputs[0], inputs[1]
+        dim = inputs[2]
+        dtype = input_types[0]
+        order = 2.0
+        reci_order = _expr.const(1.0 / order, dtype=dtype)
+        order = _expr.const(order)
+
+        norm_v = _op.power(
+            _op.reduce.sum(_op.power(_op.abs(weight_v), order), axis=dim, exclude=True, keepdims=True),
+            reci_order,
+        )
+        return weight_g * (weight_v / norm_v)
+
     # Operator mappings
     def create_convert_map(self):
         self.convert_map = {
@@ -3766,6 +3780,7 @@ def create_convert_map(self):
             "aten::__lshift__": self.make_elemwise("left_shift"),
             "aten::__rshift__": self.make_elemwise("right_shift"),
             "aten::multinomial": self.multinomial,
+            "aten::_weight_norm": self.weight_norm,
         }
 
     def update_convert_map(self, custom_map):
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 36bb5bede475..161b1766eeaa 100755
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -5038,5 +5038,29 @@ def _test_multinomial(num_samples):
     )
 
 
+def test_weight_norm():
+    """Test for aten::_weight_norm"""
+    in_channels = 32
+    out_channels = 64
+    input_data_conv = torch.rand((1, in_channels, 32, 32)).float()
+
+    conv_wn = torch.nn.utils.weight_norm(torch.nn.Conv2d(in_channels, out_channels, kernel_size=3))
+    verify_model(conv_wn.eval().float(), input_data_conv)
+
+    conv_wn_groups = torch.nn.utils.weight_norm(
+        torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, groups=2)
+    )
+    verify_model(conv_wn_groups.eval().float(), input_data_conv)
+
+    conv_wn = torch.nn.utils.weight_norm(
+        torch.nn.Conv2d(in_channels, out_channels, kernel_size=3), dim=1
+    )
+    verify_model(conv_wn.eval().float(), input_data_conv)
+
+    linear_wn = torch.nn.utils.weight_norm(torch.nn.Linear(in_channels, out_channels))
+    input_data_linear = torch.rand((128, in_channels)).float()
+    verify_model(linear_wn.eval().float(), input_data_linear)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
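
For reference, the converter added above implements the same decomposition that PyTorch's aten::_weight_norm performs: w = g * v / ||v||, where the L2 norm of v is taken over every axis except `dim` (hence axis=dim with exclude=True and keepdims=True in the Relay graph). Below is a minimal standalone sketch of that math, assuming only that torch is available; it checks the formula against torch._weight_norm, and the helper name weight_norm_reference is purely illustrative, not part of the patch.

import torch


def weight_norm_reference(v: torch.Tensor, g: torch.Tensor, dim: int = 0) -> torch.Tensor:
    # Sketch of the decomposition the converter implements (not TVM code):
    # L2 norm of v over every axis except `dim`, keeping dims so it broadcasts against v.
    reduce_axes = tuple(i for i in range(v.dim()) if i != dim)
    norm_v = v.abs().pow(2.0).sum(dim=reduce_axes, keepdim=True).pow(0.5)
    return g * (v / norm_v)


# Conv2d-style weight: (out_channels, in_channels, kH, kW), with the magnitude g
# kept per output channel, i.e. dim=0 as nn.utils.weight_norm uses by default.
v = torch.rand(64, 32, 3, 3)
g = torch.rand(64, 1, 1, 1)
torch.testing.assert_close(weight_norm_reference(v, g, 0), torch._weight_norm(v, g, 0))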