diff --git a/nni/compression/pytorch/compressor.py b/nni/compression/pytorch/compressor.py
index 01b8bb24e4..23df0e7353 100644
--- a/nni/compression/pytorch/compressor.py
+++ b/nni/compression/pytorch/compressor.py
@@ -374,7 +374,8 @@ def _wrap_modules(self, layer, config):
         wrapper.to(layer.module.weight.device)
         return wrapper
 
-    def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None, device=None):
+    def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None, device=None,
+                     dummy_input=None, opset_version=None):
         """
         Export pruned model weights, masks and onnx model(optional)
 
@@ -387,10 +388,21 @@ def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=N
         onnx_path : str (optional)
             path to save onnx model
         input_shape : list or tuple
-            input shape to onnx model
+            input shape to onnx model, used for creating a dummy input tensor for torch.onnx.export
+            if the input has a complex structure (e.g., a tuple), please directly create the input and
+            pass it to dummy_input instead
+            note: this argument is deprecated and will be removed; please use dummy_input instead
         device : torch.device
-            device of the model, used to place the dummy input tensor for exporting onnx file.
+            device of the model, where to place the dummy input tensor for exporting onnx file;
             the tensor is placed on cpu if ```device``` is None
+            only useful when both onnx_path and input_shape are passed
+            note: this argument is deprecated and will be removed; please use dummy_input instead
+        dummy_input: torch.Tensor or tuple
+            dummy input to the onnx model; used when input_shape is not enough to specify dummy input
+            user should ensure that the dummy_input is on the same device as the model
+        opset_version: int
+            opset_version parameter for torch.onnx.export; only useful when onnx_path is not None
+            if not passed, torch.onnx.export will use its default opset_version
         """
         assert model_path is not None, 'model_path must be specified'
         mask_dict = {}
@@ -411,17 +423,31 @@ def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=N
         torch.save(self.bound_model.state_dict(), model_path)
         _logger.info('Model state_dict saved to %s', model_path)
+
         if mask_path is not None:
             torch.save(mask_dict, mask_path)
             _logger.info('Mask dict saved to %s', mask_path)
+
         if onnx_path is not None:
-            assert input_shape is not None, 'input_shape must be specified to export onnx model'
-            # input info needed
-            if device is None:
-                device = torch.device('cpu')
-            input_data = torch.Tensor(*input_shape)
-            torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
-            _logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
+            assert input_shape is not None or dummy_input is not None,\
+                'input_shape or dummy_input must be specified to export onnx model'
+            # create dummy_input using input_shape if input_shape is not passed
+            if dummy_input is None:
+                _logger.warning("""The argument input_shape and device will be removed in the future.
+                                   Please create a dummy input and pass it to dummy_input instead.""")
+                if device is None:
+                    device = torch.device('cpu')
+                input_data = torch.Tensor(*input_shape).to(device)
+            else:
+                input_data = dummy_input
+            if opset_version is not None:
+                torch.onnx.export(self.bound_model, input_data, onnx_path, opset_version=opset_version)
+            else:
+                torch.onnx.export(self.bound_model, input_data, onnx_path)
+            if dummy_input is None:
+                _logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
+            else:
+                _logger.info('Model in onnx saved to %s', onnx_path)
         self._wrap_model()
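For reference, a minimal usage sketch of the new call surface (not part of the diff). It assumes an NNI 2.x pruner such as LevelPruner, which inherits export_model from Compressor; the model, config_list, file names, input size, and opset value are illustrative placeholders, and the import path may differ between NNI releases.

    import torch
    from torchvision.models import resnet18
    from nni.algorithms.compression.pytorch.pruning import LevelPruner  # assumed NNI 2.x import path

    # Placeholder model and pruning config for illustration only.
    model = resnet18()
    config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
    pruner = LevelPruner(model, config_list)
    pruner.compress()

    # Preferred path after this change: pass a ready-made dummy input (on the same
    # device as the model) and, optionally, an explicit opset_version for torch.onnx.export.
    dummy_input = torch.randn(1, 3, 224, 224)
    pruner.export_model(model_path='pruned.pth', mask_path='mask.pth',
                        onnx_path='pruned.onnx',
                        dummy_input=dummy_input, opset_version=11)

    # Deprecated path kept for backward compatibility: export_model builds the dummy
    # input from input_shape and places it on `device` (cpu when device is None).
    pruner.export_model(model_path='pruned.pth', mask_path='mask.pth',
                        onnx_path='pruned.onnx', input_shape=[1, 3, 224, 224])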