diff --git a/nni/retiarii/converter/op_types.py b/nni/retiarii/converter/op_types.py index 7d25d14d9e..4c3d283557 100644 --- a/nni/retiarii/converter/op_types.py +++ b/nni/retiarii/converter/op_types.py @@ -29,6 +29,7 @@ class OpTypeName(str, Enum): 'aten::cat': 'Cat', 'aten::size': 'Size', 'aten::view': 'View', + 'aten::reshape': 'Reshape', 'aten::eq': 'Eq', 'aten::Bool': 'Bool', 'aten::empty': 'Empty', diff --git a/nni/retiarii/operation.py b/nni/retiarii/operation.py index b1d4303d05..712d10b712 100644 --- a/nni/retiarii/operation.py +++ b/nni/retiarii/operation.py @@ -150,6 +150,9 @@ def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str: elif self.type == 'aten::view': assert len(inputs) == 2 return f'{output} = {inputs[0]}.view({inputs[1]})' + elif self.type == 'aten::reshape': + assert len(inputs) == 2 + return f'{output} = {inputs[0]}.reshape({inputs[1]})' elif self.type == 'aten::slice': raise RuntimeError('not supposed to have aten::slice operation') elif self.type == 'aten::Bool':