From 5fe245006a1003890720fe0debe3881cd0a31c79 Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Wed, 14 Jul 2021 17:15:23 +0800 Subject: [PATCH] [retiarii] support torch 1.8 and 1.9 (#3937) --- docs/en_US/NAS/QuickStart.rst | 2 +- nni/retiarii/operation_def/torch_op_def.py | 19 ++++++++++++++++--- test/ut/retiarii/test_convert_pytorch.py | 2 +- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/docs/en_US/NAS/QuickStart.rst b/docs/en_US/NAS/QuickStart.rst index 01ec33f540..3dc3670ad4 100644 --- a/docs/en_US/NAS/QuickStart.rst +++ b/docs/en_US/NAS/QuickStart.rst @@ -12,7 +12,7 @@ In this quick start tutorial, we use multi-trial NAS as an example to show how t One-shot NAS tutorial can be found `here <./OneshotTrainer.rst>`__. -.. note:: Currently, PyTorch is the only supported framework by Retiarii, and we have only tested with **PyTorch 1.6 and 1.7**. This documentation assumes PyTorch context but it should also apply to other frameworks, that is in our future plan. +.. note:: Currently, PyTorch is the only supported framework by Retiarii, and we have only tested with **PyTorch 1.6 to 1.9**. This documentation assumes a PyTorch context, but it should also apply to other frameworks, which is in our future plan. 
Define your Model Space ----------------------- diff --git a/nni/retiarii/operation_def/torch_op_def.py b/nni/retiarii/operation_def/torch_op_def.py index bb97069e63..313a5558af 100644 --- a/nni/retiarii/operation_def/torch_op_def.py +++ b/nni/retiarii/operation_def/torch_op_def.py @@ -59,7 +59,7 @@ class PrimConstant(PyTorchOperation): def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: # TODO: refactor this part, maybe we can remove the code gen of prim::Constant # TODO: deal with all the types - if self.parameters['type'] == 'None': + if self.parameters['type'] in ['None', 'NoneType']: return f'{output} = None' elif self.parameters['type'] in ('int', 'float', 'bool', 'int[]'): return f'{output} = {self.parameters["value"]}' @@ -238,7 +238,13 @@ def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_val ManuallyChooseDef = { 'aten::flatten': [('start_dim', 'int', '0'), ('end_dim', 'int', '-1')], - 'aten::split': [('split_size', 'int', 'None'), ('dim', 'int', '0')] + 'aten::split': [('split_size', 'int', 'None'), ('dim', 'int', '0')], + # in v1.9 dtype is supported as input argument for view, but torch script does not support it + 'aten::view': [('size', 'List[int]', 'None')], + # NOTE: dim supports different types: List[int], List[str], Optional[List[int]], now we only support the first two, refactor needed + # torch.std(input, dim, unbiased, keepdim=False, *, out=None) Tensor + # torch.std(input, unbiased) Tensor + 'aten::std': [('dim', 'List[int]', 'None'), ('unbiased', 'bool', 'True'), ('keepdim', 'bool', 'False')] } TensorOpExceptions = { @@ -426,4 +432,11 @@ class AtenAvgpool2d(PyTorchOperation): # NOTE: it is not included in the above aten ops for unkown reason _ori_type_name = ['aten::avg_pool2d'] def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: - return f'{output} = F.avg_pool2d({", ".join(inputs)})' \ No newline at 
end of file + return f'{output} = F.avg_pool2d({", ".join(inputs)})' + +class AtenDet(PyTorchOperation): + # for torch 1.9 + # NOTE: it is not included in the above aten ops, maybe because torch.det is an alias for torch.linalg.det + _ori_type_name = ['aten::linalg_det'] + def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: + return f'{output} = torch.det({inputs[0]})' \ No newline at end of file diff --git a/test/ut/retiarii/test_convert_pytorch.py b/test/ut/retiarii/test_convert_pytorch.py index dbcf1acd31..51857b6815 100644 --- a/test/ut/retiarii/test_convert_pytorch.py +++ b/test/ut/retiarii/test_convert_pytorch.py @@ -375,7 +375,7 @@ def forward(self, input): # NOTE: torch script gets an incorrect graph... def test_optional_inputs_with_mixed_optionals(self): class MixedModel(nn.Module): - def forward(self, x: 'Tensor', y: 'Tensor', z: 'Tensor'): + def forward(self, x, y, z): if y is not None: return x + y if z is not None: