diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index e8c36d8cad54..99a3985a2993 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -1571,7 +1571,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         return success();
       });
   patterns.onOp(
-      "ConstantOfShape", 20,
+      "ConstantOfShape", 1,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType resultType;
         Value shape;
@@ -1582,15 +1582,14 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         auto shapeSizes =
            dyn_cast<Torch::ValueTensorType>(shape.getType()).getSizes();
         SmallVector<Value> dimList;
-        SmallVector<int64_t> selectSizes;
-        selectSizes.push_back(1);
         Torch::BaseTensorType shapeType =
            shape.getType().cast<Torch::BaseTensorType>();
-        Type selectResultType = shapeType.getWithSizesAndDtype(
-            llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype());
+        Type selectResultType = rewriter.getType<Torch::ValueTensorType>(
+            ArrayRef<int64_t>({}), shapeType.getOptionalDtype());
         Value zero = rewriter.create<Torch::ConstantIntOp>(
             binder.getLoc(), rewriter.getType<Torch::IntType>(),
             rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+
         for (int i = 0; i < shapeSizes[0]; i++) {
           Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
               binder.getLoc(), rewriter.getType<Torch::IntType>(),
@@ -1601,6 +1600,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
               binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
           dimList.push_back(dim);
         }
+
         Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
             binder.getLoc(),
             Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
@@ -1609,7 +1609,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(

         // Get fill_value if it is present.
         // Assumption : resultDType and value attr type match.
-        Value value_const;
         auto attr = binder.op->getAttr("torch.onnx.value");
         auto resultDType = resultType.getDtype();
@@ -1620,34 +1619,58 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
               resultType.toBuiltinTensor().clone(resultDType),
               rewriter.getFloatAttr(resultDType, 0.0));
         }
-        if (!isa<DenseElementsAttr>(attr)) {
-          return rewriter.notifyMatchFailure(
-              binder.op, "`value` attr needs to be a tensor.");
+
+        // If it's a dense resource attr we need to convert to a dense type:
+        if (DenseResourceElementsAttr rattr =
+                attr.dyn_cast_or_null<DenseResourceElementsAttr>()) {
+          // Bytes are stored in little endian order. Big endian support will
+          // require swizzling.
+          if (!Endian::little) {
+            binder.op->emitError(
+                "unimplemented: importing on big endian systems");
+            return failure();
+          }
+
+          auto ty = cast<ShapedType>(rattr.getType());
+          auto ptr = rattr.getRawHandle().getBlob()->getData();
+          auto denseAttr = DenseElementsAttr::getFromRawBuffer(ty, ptr);
+          attr = dyn_cast_or_null<SplatElementsAttr>(denseAttr);
+        }
+
+        Attribute splattr;
+        if (isa<SplatElementsAttr>(attr)) {
+          auto denseAttr = attr.cast<SplatElementsAttr>();
+          splattr = denseAttr.getSplatValue<Attribute>();
         }
-        auto denseAttr = attr.cast<DenseElementsAttr>();
-        auto denseAttrEleType = denseAttr.getElementType();
-        if (!isa<IntegerType, FloatType>(denseAttrEleType)) {
+        if (!isa<FloatAttr, IntegerAttr>(splattr)) {
           return rewriter.notifyMatchFailure(
               binder.op,
               "`value` attr tensor only supports types int and float for now.");
         }
-        // Create constant op for value
-        if (denseAttrEleType.isa<IntegerType>()) {
-          int64_t intVal = denseAttr.getSplatValue<IntegerAttr>().getSInt();
-          value_const = rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getI64IntegerAttr(intVal));
-        }
-        if (denseAttrEleType.isa<FloatType>()) {
-          float floatVal =
-              denseAttr.getSplatValue<FloatAttr>().getValue().convertToFloat();
-          value_const = rewriter.create<Torch::ConstantFloatOp>(
-              binder.getLoc(), rewriter.getF64FloatAttr(floatVal));
+        Value splatvalue;
+        if (auto intattr = dyn_cast<IntegerAttr>(splattr)) {
+          IntegerType intty = cast<IntegerType>(intattr.getType());
+          int64_t value;
+          if (intty.isUnsignedInteger()) {
+            value = intattr.getUInt();
+          } else if (intty.isSignedInteger()) {
+            value = intattr.getSInt();
+          } else {
+            value = intattr.getInt();
+          }
+          splatvalue =
+              rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), value);
         }
+        if (auto fpattr = dyn_cast<FloatAttr>(splattr))
+          splatvalue = rewriter.create<Torch::ConstantFloatOp>(
+              binder.getLoc(),
+              rewriter.getF64FloatAttr(fpattr.getValueAsDouble()));
+
         rewriter.replaceOpWithNewOp<Torch::AtenFullOp>(
-            binder.op, resultType, dimValueList, value_const, /*dtype=*/noneVal,
+            binder.op, resultType, dimValueList, splatvalue, /*dtype=*/noneVal,
             /*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal);
         return success();
       });
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index e600a6be8a52..e749b5834cc6 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1825,23 +1825,6 @@
     "ElementwiseClampTensorFloatModule_basic",
     "ElementwiseClampTensorInt8Module_basic",
     "ElementwiseClampTensorIntModule_basic",
-    "EmptyLikeMemoryFormatModule_basic",
-    "EmptyLikeModule_defaultDtype",
-    "EmptyLikeModule_falsePinMemory",
-    "EmptyLikeModule_float",
-    "EmptyLikeModule_int",
-    "Fill_TensorFloat32WithFloat32_basic",
-    "Fill_TensorFloat32WithFloat64_basic",
-    "Fill_TensorFloat32WithInt64_basic",
-    "Fill_TensorFloat64WithFloat32_basic",
-    "Fill_TensorFloat64WithFloat64_basic",
-    "Fill_TensorFloat64WithInt64_basic",
-    "FullLikeModuleDefaultDtype_basic",
-    "FullLikeModuleFalsePinMemory_basic",
-    "FullLikeModuleFloat2D_basic",
-    "FullLikeModuleFloat3D_basic",
-    "FullLikeModuleInt2D_basic",
-    "FullLikeModuleInt3D_basic",
     "HBC_basic",
     "IndexPut1DFloatAccumulateModule_basic",
     "IndexPut1DIntAccumulateModule_basic",
@@ -1856,10 +1839,6 @@
     "IndexPutHackedTwin3DFloatAccumulateModule_basic",
     "IndexPutHackedTwin3DIntAccumulateModule_basic",
     "NormalizeModule_basic",
-    "OnesLikeModule_defaultDtype",
-    "OnesLikeModule_falsePinMemory",
-    "OnesLikeModule_float",
-    "OnesLikeModule_int",
     "PadWithNoneValModule_basic",
     "QuantizedMLP_basic",
     "RandModule_basic",
@@ -1875,13 +1854,6 @@
     "TileSmallDimsSizeModule_basic",
     "UpSampleNearest2dDynamicSize_basic",
     "UpSampleNearest2dStaticSize_basic",
-    "ZeroFloat32Module_basic",
-    "ZeroInt32Module_basic",
-    "ZeroInt64Module_basic",
-    "ZerosLikeModule_defaultDtype",
"ZerosLikeModule_falsePinMemory", - "ZerosLikeModule_float", - "ZerosLikeModule_int", # Failure - onnx_lowering "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index c62324832520..a0cfbf26ed30 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -408,37 +408,6 @@ def _handle_node_Constant(self, node: onnx.NodeProto) -> bool: self._gi.initializer_map[const_name] = value_proto.t return True - def _handle_node_ConstantOfShape(self, node: onnx.NodeProto) -> bool: - # This op is special: It has an input of the shape, and in full generality - # could involve eager production of constants of variable size. In - # practice, the DNN profile for ONNX makes this very difficult to do - # and we hard-assert that the input can be resolved to an immediate - # value. - assert len(node.input) == 1 - assert len(node.output) == 1 - shape = self._get_immediate_tensor(node.input[0]).astype(np.int64) - value_proto = _get_attr(node, "value") - assert value_proto.type == onnx.AttributeProto.AttributeType.TENSOR - tensor_proto = value_proto.t - element_type = self._cc.tensor_element_type(tensor_proto.data_type) - vtensor_type = self._cc.get_vtensor_type(tuple(shape), element_type) - assert len(tensor_proto.dims) == 1 and tensor_proto.dims[0] == 1 - try: - cb = ELEM_TYPE_SPLAT_TENSOR_PROTO_CB[tensor_proto.data_type] - except KeyError: - raise OnnxImportError( - f"Unhandled splat type for ConstantOfShape: {node} (possible missing mapping in ELEM_TYPE_SPLAT_TENSOR_PROTO_CB)" - ) - value_attr = cb(tensor_proto, tuple(shape)) - literal_op = Operation.create( - name="torch.vtensor.literal", - results=[vtensor_type], - attributes={"value": value_attr}, - ) - self._nv_map[node.output[0]] = literal_op.result - return True - - class ContextCache: """Caches per-context lookups of various things.""" diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 3e4a476dbfbb..525583b7660e 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -1539,14 +1539,14 @@ func.func @test_constant_of_shape_dense_float_default() -> !torch.vtensor<[2,3,4 // CHECK: %[[SHAPE_CST:.*]] = torch.vtensor.literal(dense<[2, 3, 4]> : tensor<3xsi64>) : !torch.vtensor<[3],si64> // CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 - // CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[INT1:.*]] = torch.constant.int 1 - // CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, 
!torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[INT2:.*]] = torch.constant.int 2 - // CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[ELE_0]], %[[ELE_1]], %[[ELE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[FILL_VAL:.*]] = torch.constant.float 0.000000e+00 @@ -1563,14 +1563,14 @@ func.func @test_constant_of_shape_dense_float_cst() -> !torch.vtensor<[2,3,4], f // CHECK: %[[SHAPE_CST:.*]] = torch.vtensor.literal(dense<[2, 3, 4]> : tensor<3xsi64>) : !torch.vtensor<[3],si64> // CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 - // CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[INT1:.*]] = torch.constant.int 1 - // CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[INT2:.*]] = torch.constant.int 2 - // CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[ELE_0]], %[[ELE_1]], %[[ELE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[FILL_VAL:.*]] = torch.constant.float 3.4000000953674316 @@ -1587,14 +1587,14 @@ func.func @test_constant_of_shape_dense_int_cst() -> !torch.vtensor<[2,3,4], si6 // CHECK: %[[SHAPE_CST:.*]] = torch.vtensor.literal(dense<[2, 3, 4]> : tensor<3xsi64>) : !torch.vtensor<[3],si64> // CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[INT0_0:.*]] = 
torch.constant.int 0 - // CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[INT1:.*]] = torch.constant.int 1 - // CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[INT2:.*]] = torch.constant.int 2 - // CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[ELE_0]], %[[ELE_1]], %[[ELE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[FILL_VAL:.*]] = torch.constant.int 3 diff --git a/test/python/onnx_importer/import_smoke_test.py b/test/python/onnx_importer/import_smoke_test.py index 708324e72db6..22d460050cae 100644 --- a/test/python/onnx_importer/import_smoke_test.py +++ b/test/python/onnx_importer/import_smoke_test.py @@ -102,9 +102,6 @@ "node_test_castlike_FLOAT_to_STRING_model", "node_test_castlike_STRING_to_FLOAT_expanded_model", "node_test_castlike_STRING_to_FLOAT_model", - "node_test_constantofshape_float_ones_model", - "node_test_constantofshape_int_shape_zero_model", - "node_test_constantofshape_int_zeros_model", "node_test_dequantizelinear_e4m3fn_model", "node_test_dequantizelinear_e4m3fn_zero_point_model", "node_test_dequantizelinear_e5m2_model", @@ -118,44 +115,6 @@ "node_test_if_model", "node_test_if_opt_model", "node_test_if_seq_model", - "node_test_layer_normalization_2d_axis0_expanded_model", - "node_test_layer_normalization_2d_axis0_expanded_ver18_model", - "node_test_layer_normalization_2d_axis1_expanded_model", - "node_test_layer_normalization_2d_axis1_expanded_ver18_model", - "node_test_layer_normalization_2d_axis_negative_1_expanded_model", - "node_test_layer_normalization_2d_axis_negative_1_expanded_ver18_model", - "node_test_layer_normalization_2d_axis_negative_2_expanded_model", - "node_test_layer_normalization_2d_axis_negative_2_expanded_ver18_model", - "node_test_layer_normalization_3d_axis0_epsilon_expanded_model", - "node_test_layer_normalization_3d_axis0_epsilon_expanded_ver18_model", - "node_test_layer_normalization_3d_axis1_epsilon_expanded_model", - 
"node_test_layer_normalization_3d_axis1_epsilon_expanded_ver18_model", - "node_test_layer_normalization_3d_axis2_epsilon_expanded_model", - "node_test_layer_normalization_3d_axis2_epsilon_expanded_ver18_model", - "node_test_layer_normalization_3d_axis_negative_1_epsilon_expanded_model", - "node_test_layer_normalization_3d_axis_negative_1_epsilon_expanded_ver18_model", - "node_test_layer_normalization_3d_axis_negative_2_epsilon_expanded_model", - "node_test_layer_normalization_3d_axis_negative_2_epsilon_expanded_ver18_model", - "node_test_layer_normalization_3d_axis_negative_3_epsilon_expanded_model", - "node_test_layer_normalization_3d_axis_negative_3_epsilon_expanded_ver18_model", - "node_test_layer_normalization_4d_axis0_expanded_model", - "node_test_layer_normalization_4d_axis0_expanded_ver18_model", - "node_test_layer_normalization_4d_axis1_expanded_model", - "node_test_layer_normalization_4d_axis1_expanded_ver18_model", - "node_test_layer_normalization_4d_axis2_expanded_model", - "node_test_layer_normalization_4d_axis2_expanded_ver18_model", - "node_test_layer_normalization_4d_axis3_expanded_model", - "node_test_layer_normalization_4d_axis3_expanded_ver18_model", - "node_test_layer_normalization_4d_axis_negative_1_expanded_model", - "node_test_layer_normalization_4d_axis_negative_1_expanded_ver18_model", - "node_test_layer_normalization_4d_axis_negative_2_expanded_model", - "node_test_layer_normalization_4d_axis_negative_2_expanded_ver18_model", - "node_test_layer_normalization_4d_axis_negative_3_expanded_model", - "node_test_layer_normalization_4d_axis_negative_3_expanded_ver18_model", - "node_test_layer_normalization_4d_axis_negative_4_expanded_model", - "node_test_layer_normalization_4d_axis_negative_4_expanded_ver18_model", - "node_test_layer_normalization_default_axis_expanded_model", - "node_test_layer_normalization_default_axis_expanded_ver18_model", "node_test_loop11_model", "node_test_loop13_seq_model", "node_test_loop16_seq_none_model",