diff --git a/nni/compression/pytorch/speedup/jit_translate.py b/nni/compression/pytorch/speedup/jit_translate.py index ac051c73af..f0e1098093 100644 --- a/nni/compression/pytorch/speedup/jit_translate.py +++ b/nni/compression/pytorch/speedup/jit_translate.py @@ -253,6 +253,13 @@ def squeeze_python(node, speedup): new_squeeze = partial(torch.squeeze, dim=dim) return new_squeeze +def unsqueeze_python(node, speedup): + c_node = node.key_node + inputs = list(c_node.inputs()) + dim = parse_constant(inputs[1], speedup) + new_unsqueeze = partial(torch.unsqueeze, dim=dim) + return new_unsqueeze + ########################################################## # Split Line # Following module/functions cannot be translated into a @@ -517,6 +524,7 @@ def forward(self, *args): 'aten::upsample_bilinear2d': upsample_bilinear2d_python, 'aten::exp': exp_python, 'aten::squeeze': squeeze_python, + 'aten::unsqueeze': unsqueeze_python, 'prim::TupleUnpack': tupleunpack_python, 'prim::ListUnpack': tupleunpack_python, 'prim::NumToTensor': num2tensor_python,