From 53161f16614b51102c578f414119b62a8b57e310 Mon Sep 17 00:00:00 2001 From: Ningxin Zheng Date: Mon, 19 Jul 2021 11:52:42 +0000 Subject: [PATCH] Support UnSqueeze --- nni/compression/pytorch/speedup/jit_translate.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/nni/compression/pytorch/speedup/jit_translate.py b/nni/compression/pytorch/speedup/jit_translate.py index ac051c73af..f0e1098093 100644 --- a/nni/compression/pytorch/speedup/jit_translate.py +++ b/nni/compression/pytorch/speedup/jit_translate.py @@ -253,6 +253,13 @@ def squeeze_python(node, speedup): new_squeeze = partial(torch.squeeze, dim=dim) return new_squeeze +def unsqueeze_python(node, speedup): + c_node = node.key_node + inputs = list(c_node.inputs()) + dim = parse_constant(inputs[1], speedup) + new_unsqueeze = partial(torch.unsqueeze, dim=dim) + return new_unsqueeze + ########################################################## # Split Line # Following module/functions cannot be translated into a @@ -517,6 +524,7 @@ def forward(self, *args): 'aten::upsample_bilinear2d': upsample_bilinear2d_python, 'aten::exp': exp_python, 'aten::squeeze': squeeze_python, + 'aten::unsqueeze': unsqueeze_python, 'prim::TupleUnpack': tupleunpack_python, 'prim::ListUnpack': tupleunpack_python, 'prim::NumToTensor': num2tensor_python,