Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
[retiarii] support torch 1.8 and 1.9 (#3937)
Browse files Browse the repository at this point in the history
  • Loading branch information
QuanluZhang authored Jul 14, 2021
1 parent 542a660 commit 5fe2450
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 5 deletions.
2 changes: 1 addition & 1 deletion docs/en_US/NAS/QuickStart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ In this quick start tutorial, we use multi-trial NAS as an example to show how t

One-shot NAS tutorial can be found `here <./OneshotTrainer.rst>`__.

.. note:: Currently, PyTorch is the only supported framework by Retiarii, and we have only tested with **PyTorch 1.6 and 1.7**. This documentation assumes PyTorch context but it should also apply to other frameworks, that is in our future plan.
.. note:: Currently, PyTorch is the only supported framework by Retiarii, and we have only tested with **PyTorch 1.6 to 1.9**. This documentation assumes PyTorch context but it should also apply to other frameworks, that is in our future plan.

Define your Model Space
-----------------------
Expand Down
19 changes: 16 additions & 3 deletions nni/retiarii/operation_def/torch_op_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class PrimConstant(PyTorchOperation):
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
# TODO: refactor this part, maybe we can remove the code gen of prim::Constant
# TODO: deal with all the types
if self.parameters['type'] == 'None':
if self.parameters['type'] in ['None', 'NoneType']:
return f'{output} = None'
elif self.parameters['type'] in ('int', 'float', 'bool', 'int[]'):
return f'{output} = {self.parameters["value"]}'
Expand Down Expand Up @@ -238,7 +238,13 @@ def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_val

ManuallyChooseDef = {
'aten::flatten': [('start_dim', 'int', '0'), ('end_dim', 'int', '-1')],
'aten::split': [('split_size', 'int', 'None'), ('dim', 'int', '0')]
'aten::split': [('split_size', 'int', 'None'), ('dim', 'int', '0')],
# in v1.9 dtype is supported as input argument for view, but torch script does not support it
'aten::view': [('size', 'List[int]', 'None')],
# NOTE: dim supports different types: List[int], List[str], Optional[List[int]], now we only support the first two, refactor needed
# torch.std(input, dim, unbiased, keepdim=False, *, out=None) Tensor
# torch.std(input, unbiased) Tensor
'aten::std': [('dim', 'List[int]', 'None'), ('unbiased', 'bool', 'True'), ('keepdim', 'bool', 'False')]
}

TensorOpExceptions = {
Expand Down Expand Up @@ -426,4 +432,11 @@ class AtenAvgpool2d(PyTorchOperation):
# NOTE: it is not included in the above aten ops for unkown reason
_ori_type_name = ['aten::avg_pool2d']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = F.avg_pool2d({", ".join(inputs)})'
return f'{output} = F.avg_pool2d({", ".join(inputs)})'

class AtenDet(PyTorchOperation):
# for torch 1.9
# NOTE: it is not included in the above aten ops, maybe because torch.det is alias for torch.linalg.det
_ori_type_name = ['aten::linalg_det']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = torch.det({inputs[0]})'
2 changes: 1 addition & 1 deletion test/ut/retiarii/test_convert_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def forward(self, input):
# NOTE: torch script gets an incorrect graph...
def test_optional_inputs_with_mixed_optionals(self):
class MixedModel(nn.Module):
def forward(self, x: 'Tensor', y: 'Tensor', z: 'Tensor'):
def forward(self, x, y, z):
if y is not None:
return x + y
if z is not None:
Expand Down

0 comments on commit 5fe2450

Please sign in to comment.