From ffc4d3c983a639ffcd190420639a9f751e8ae4b9 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Tue, 30 Jan 2024 14:19:35 -0800 Subject: [PATCH] [onnx] Import `onnx` constants as `onnx.Constant` instead of literals To handle the conversion from raw bytes to `DenseElementsAttr` we need to handle the endianness conversion during `torch-onnx-to-torch`. Therefore when importing `onnx.Constant` it is better to represent using the `onnx` constant operation so that only one location requires the endianness correction. --- python/torch_mlir/extras/onnx_importer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index 59a2682bbba9..c651f79b15fe 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -343,10 +343,14 @@ def import_initializer(self, initializer: onnx.TensorProto, extern_name: str = N with InsertionPoint(self._b), Location.name(iname): value_attr = self._cc.tensor_proto_to_attr(initializer) vtensor_type = self._cc.tensor_proto_to_type(initializer) + attrs = { + "name": StringAttr.get(f"onnx.Constant"), + "torch.onnx.value": value_attr, + } literal_op = Operation.create( - name="torch.vtensor.literal", + name="torch.operator", results=[vtensor_type], - attributes={"value": value_attr}, + attributes=attrs, ) self._nv_map[iname] = literal_op.result return literal_op.result