Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Fix KeyError: 'aten::reshape' on calling torch.reshape (#3341)
Browse files Browse the repository at this point in the history
  • Loading branch information
tczhangzhi authored Feb 2, 2021
1 parent 533f2ef commit be3a696
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
1 change: 1 addition & 0 deletions nni/retiarii/converter/op_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class OpTypeName(str, Enum):
'aten::cat': 'Cat',
'aten::size': 'Size',
'aten::view': 'View',
'aten::reshape': 'Reshape',
'aten::eq': 'Eq',
'aten::Bool': 'Bool',
'aten::empty': 'Empty',
Expand Down
3 changes: 3 additions & 0 deletions nni/retiarii/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
elif self.type == 'aten::view':
assert len(inputs) == 2
return f'{output} = {inputs[0]}.view({inputs[1]})'
elif self.type == 'aten::reshape':
assert len(inputs) == 2
return f'{output} = {inputs[0]}.reshape({inputs[1]})'
elif self.type == 'aten::slice':
raise RuntimeError('not supposed to have aten::slice operation')
elif self.type == 'aten::Bool':
Expand Down

0 comments on commit be3a696

Please sign in to comment.