diff --git a/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp b/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp index d4fa675dfe..7a0d1acdbf 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp @@ -12,6 +12,20 @@ using namespace mlir; +bool checkOpResultIsReturned(ONNXConstantOp *constantOp) { + FuncOp function = getContainingFunction(constantOp->getOperation()); + + bool opIsReturned = false; + function.walk([&opIsReturned, constantOp](ReturnOp op) { + auto result = constantOp->getResult(); + for (const auto &operand : op.getOperands()) + if (operand == result) + opIsReturned = true; + }); + + return opIsReturned; +} + struct ONNXConstantOpLowering : public ConversionPattern { static int constantID; @@ -47,9 +61,30 @@ struct ONNXConstantOpLowering : public ConversionPattern { // Increment constant ID: constantID++; - // Replace this operation with the generated alloc. - // rewriter.replaceOp(op, alloc); - rewriter.replaceOp(op, constantGlobal.getResult()); + // Check if the variable is returned. + if (checkOpResultIsReturned(&constantOp)) { + // In this case, use an AllocOp for the constant since krnl.Global + // operations are not meant to be returned. + AllocOp alloc = rewriter.create(loc, memRefType); + + // Compute size in bytes using the input tensor. + Value tensorSize = emitConstantOp(rewriter, loc, + rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType)); + auto numElementsValue = emitConstantOp( + rewriter, loc, rewriter.getIntegerType(64), numElements); + tensorSize = rewriter.create(loc, tensorSize, numElementsValue); + + // Copy the value in the AllocOp. + rewriter.create( + loc, alloc, constantGlobal.getResult(), tensorSize); + + // Since the value is returned we need to only work with the AllocOp + // not the KrnlGlobalOp. Globals cannot be returned. + rewriter.replaceOp(op, alloc.getResult()); + } else { + // Replace this operation with the generated krnl.global. 
+ rewriter.replaceOp(op, constantGlobal.getResult()); + } return success(); } diff --git a/test/backend/test.py b/test/backend/test.py index a44eb84134..d1c448b58b 100644 --- a/test/backend/test.py +++ b/test/backend/test.py @@ -570,8 +570,8 @@ # Size # TODO(tjingrant): fix unit test for size ops. - # "test_size_cpu": (test_static,), - # "test_size_example_cpu": (test_static,), + "test_size_cpu": (test_static,), + "test_size_example_cpu": (test_static,), # Slice (makes Axis a runtime argument, which is not supported). diff --git a/test/mlir/krnl/constant.mlir b/test/mlir/krnl/constant.mlir index 5790529f2d..53f13b811c 100644 --- a/test/mlir/krnl/constant.mlir +++ b/test/mlir/krnl/constant.mlir @@ -2,14 +2,33 @@ // ----- -func @test_constant(%arg0 : tensor<1xf32>) -> tensor<*xf32> { +func @test_constant(%arg0 : tensor<3x2xf32>) -> tensor<*xf32> { %0 = "onnx.Constant"() {value = dense<[[0.0, 0.0], [1.0, 1.1], [2.0, 2.1]]> : tensor<3x2xf32>} : () -> tensor<*xf32> - "std.return"(%0) : (tensor<*xf32>) -> () + %1 = "onnx.Relu"(%0) : (tensor<*xf32>) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () // CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm.ptr, !llvm.ptr, !llvm.i64, !llvm.i1) // CHECK: llvm.mlir.global internal constant [[GLOBAL_CONST:@.+]](dense<{{.*}}[0.000000e+00, 0.000000e+00], [1.000000e+00, 1.100000e+00], [2.000000e+00, 2.100000e+00]{{.*}}> : tensor<3x2xf32>) : !llvm.array<3 x array<2 x float>> // CHECK: llvm.func @test_constant({{.*}}) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> { + // CHECK: [[CONST_3:%.+]] = llvm.mlir.constant(3 : index) : !llvm.i64 + // CHECK: [[CONST_4:%.+]] = llvm.mlir.constant(2 : index) : !llvm.i64 + + /// This is the result MemRef: + // CHECK: [[MALLOC_FOR_RES:%.+]] = llvm.call @malloc + // CHECK: [[CAST_MALLOC_FOR_RES:%.+]] = llvm.bitcast [[MALLOC_FOR_RES]] : !llvm.ptr to !llvm.ptr + // CHECK: [[RES_MEMREF:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // 
CHECK: [[RES_MEMREF_1:%.+]] = llvm.insertvalue [[CAST_MALLOC_FOR_RES]], [[RES_MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[RES_MEMREF_2:%.+]] = llvm.insertvalue [[CAST_MALLOC_FOR_RES]], [[RES_MEMREF_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[CONST_0:%.+]] = llvm.mlir.constant(0 : index) : !llvm.i64 + // CHECK: [[RES_MEMREF_3:%.+]] = llvm.insertvalue [[CONST_0]], [[RES_MEMREF_2]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[CONST_1:%.+]] = llvm.mlir.constant(1 : index) : !llvm.i64 + // CHECK: [[CONST_2:%.+]] = llvm.mlir.constant(2 : index) : !llvm.i64 + // CHECK: [[RES_MEMREF_4:%.+]] = llvm.insertvalue [[CONST_3]], [[RES_MEMREF_3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[RES_MEMREF_5:%.+]] = llvm.insertvalue [[CONST_2]], [[RES_MEMREF_4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[RES_MEMREF_6:%.+]] = llvm.insertvalue [[CONST_4]], [[RES_MEMREF_5]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[RES_MEMREF_7:%.+]] = llvm.insertvalue [[CONST_1]], [[RES_MEMREF_6]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[CONST1:%.+]] = llvm.mlir.constant(1 : i64) : !llvm.i64 // CHECK: [[ALLOCA:%.+]] = llvm.alloca [[CONST1]] x !llvm.array<3 x array<2 x float>> : (!llvm.i64) -> !llvm.ptr>> // CHECK: [[I8ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm.ptr>> to !llvm.ptr @@ -51,5 +70,82 @@ func @test_constant(%arg0 : tensor<1xf32>) -> tensor<*xf32> { // CHECK: [[CONST1:%.+]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK: [[MEMREF5:%.+]] = llvm.insertvalue [[CONST1]], [[MEMREF4]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: llvm.return [[MEMREF5]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: llvm.return [[RES_MEMREF_7]] : 
!llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +} + +// ----- + +func @test_constant(%arg0 : tensor<3x2xf32>) -> tensor<*xf32> { + %0 = "onnx.Constant"() {value = dense<[[0.0, 0.0], [1.0, 1.1], [2.0, 2.1]]> : tensor<3x2xf32>} : () -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK: [[CONST1:%.+]] = llvm.mlir.constant(1 : i64) : !llvm.i64 + // CHECK: [[ALLOCA:%.+]] = llvm.alloca [[CONST1]] x !llvm.array<3 x array<2 x float>> : (!llvm.i64) -> !llvm.ptr>> + // CHECK: [[I8ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm.ptr>> to !llvm.ptr + + // CHECK: [[GLOBAL_ADDR:%.+]] = llvm.mlir.addressof [[GLOBAL_CONST]] : !llvm.ptr>> + // CHECK: [[I8GLOBAL:%.+]] = llvm.bitcast [[GLOBAL_ADDR]] : !llvm.ptr>> to !llvm.ptr + + /// Size of the constant tensor in bytes. + // CHECK: [[CONST4:%.+]] = llvm.mlir.constant(4 : i64) : !llvm.i64 + // CHECK: [[CONST6:%.+]] = llvm.mlir.constant(6 : i64) : !llvm.i64 + // CHECK: [[CONST_MUL1:%.+]] = llvm.mul [[CONST4]], [[CONST6]] : !llvm.i64 + // CHECK: [[GLOBAL_SIZE_BYTES:%.+]] = llvm.sext [[CONST_MUL1]] : !llvm.i64 to !llvm.i64 + + /// Volatile flag + // CHECK: [[CONST0:%.+]] = llvm.mlir.constant(false) : !llvm.i1 + + // CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[I8ALLOCA]], [[I8GLOBAL]], [[GLOBAL_SIZE_BYTES]], [[CONST0]]) : (!llvm.ptr, !llvm.ptr, !llvm.i64, !llvm.i1) -> () + + /// Prepare data for MemRef insertion. + // CHECK: [[TYPED_ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm.ptr>> to !llvm.ptr + + /// Insert the constant value in the local MemRef. + // CHECK: [[LOCAL_MEMREF:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[LOCAL_MEMREF0:%.+]] = llvm.insertvalue [[TYPED_ALLOCA]], [[LOCAL_MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[LOCAL_MEMREF1:%.+]] = llvm.insertvalue [[TYPED_ALLOCA]], [[LOCAL_MEMREF0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + + /// Insert offset. 
+ // CHECK: [[CONST00:%.+]] = llvm.mlir.constant(0 : index) : !llvm.i64 + // CHECK: [[MEMREF1:%.+]] = llvm.insertvalue [[CONST00]], [[LOCAL_MEMREF1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + + /// Insert sizes and strides. + // CHECK: [[CONST3:%.+]] = llvm.mlir.constant(3 : index) : !llvm.i64 + // CHECK: [[MEMREF2:%.+]] = llvm.insertvalue [[CONST3]], [[MEMREF1]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[CONST1:%.+]] = llvm.mlir.constant(2 : index) : !llvm.i64 + // CHECK: [[MEMREF3:%.+]] = llvm.insertvalue [[CONST1]], [[MEMREF2]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + + // CHECK: [[CONST2:%.+]] = llvm.mlir.constant(2 : index) : !llvm.i64 + // CHECK: [[MEMREF4:%.+]] = llvm.insertvalue [[CONST2]], [[MEMREF3]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[CONST1:%.+]] = llvm.mlir.constant(1 : index) : !llvm.i64 + // CHECK: [[MEMREF5:%.+]] = llvm.insertvalue [[CONST1]], [[MEMREF4]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + + // CHECK: [[CONST_3:%.+]] = llvm.mlir.constant(3 : index) : !llvm.i64 + // CHECK: [[CONST_4:%.+]] = llvm.mlir.constant(2 : index) : !llvm.i64 + + /// This is the result MemRef: + // CHECK: [[MALLOC_FOR_RES:%.+]] = llvm.call @malloc + // CHECK: [[CAST_MALLOC_FOR_RES:%.+]] = llvm.bitcast [[MALLOC_FOR_RES]] : !llvm.ptr to !llvm.ptr + // CHECK: [[RES_MEMREF:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[RES_MEMREF_1:%.+]] = llvm.insertvalue [[CAST_MALLOC_FOR_RES]], [[RES_MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[RES_MEMREF_2:%.+]] = llvm.insertvalue [[CAST_MALLOC_FOR_RES]], [[RES_MEMREF_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[CONST_0:%.+]] = llvm.mlir.constant(0 : index) : !llvm.i64 + // CHECK: [[RES_MEMREF_3:%.+]] = llvm.insertvalue 
[[CONST_0]], [[RES_MEMREF_2]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[CONST_1:%.+]] = llvm.mlir.constant(1 : index) : !llvm.i64 + // CHECK: [[CONST_2:%.+]] = llvm.mlir.constant(2 : index) : !llvm.i64 + // CHECK: [[RES_MEMREF_4:%.+]] = llvm.insertvalue [[CONST_3]], [[RES_MEMREF_3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[RES_MEMREF_5:%.+]] = llvm.insertvalue [[CONST_2]], [[RES_MEMREF_4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[RES_MEMREF_6:%.+]] = llvm.insertvalue [[CONST_4]], [[RES_MEMREF_5]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[RES_MEMREF_7:%.+]] = llvm.insertvalue [[CONST_1]], [[RES_MEMREF_6]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + + /// Copy result in a MemRef: + // CHECK: [[CONST_5:%.+]] = llvm.mlir.constant(24 : i64) : !llvm.i64 + // CHECK: [[OUT_DATA:%.+]] = llvm.extractvalue [[RES_MEMREF_7]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[TYPED_OUT_DATA:%.+]] = llvm.bitcast [[OUT_DATA]] : !llvm.ptr to !llvm.ptr + // CHECK: [[GLOBAL_DATA:%.+]] = llvm.extractvalue [[MEMREF5]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[TYPED_GLOBAL_DATA:%.+]] = llvm.bitcast [[GLOBAL_DATA]] : !llvm.ptr to !llvm.ptr + // CHECK: [[EXTENDED_CONST_5:%.+]] = llvm.sext [[CONST_5]] : !llvm.i64 to !llvm.i64 + // CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : !llvm.i1 + // CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[TYPED_OUT_DATA]], [[TYPED_GLOBAL_DATA]], [[EXTENDED_CONST_5]], [[FALSE]]) : (!llvm.ptr, !llvm.ptr, !llvm.i64, !llvm.i1) -> () + // CHECK: llvm.return [[RES_MEMREF_7]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> } diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index 9da5ca5233..5b4eb4bf88 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ 
b/test/mlir/onnx/onnx_lowering.mlir @@ -13,8 +13,13 @@ func @test_no_argument_2() -> tensor<*xf32> { // CHECK: test_no_argument_1 // CHECK-NEXT: test_no_argument_2 -// CHECK: [[RES:%.+]] = "{{.*}}"({{.*}}) {{.*}} : ({{.*}}) -> memref<2x2xf32> -// CHECK: return [[RES]] : memref<2x2xf32> +// CHECK: [[GLOBAL:%.+]] = "{{.*}}"({{.*}}) {{.*}} : ({{.*}}) -> memref<2x2xf32> +// CHECK: [[ALLOC:%.+]] = alloc() : memref<2x2xf32> +// CHECK: [[CONST_4:%.+]] = constant 4 : i64 +// CHECK: [[CONST_4_0:%.+]] = constant 4 : i64 +// CHECK: [[SIZE:%.+]] = muli [[CONST_4]], [[CONST_4_0]] : i64 +// CHECK: "krnl.memcpy"([[ALLOC]], [[GLOBAL]], [[SIZE]]) : (memref<2x2xf32>, memref<2x2xf32>, i64) -> () +// CHECK: return [[ALLOC]] : memref<2x2xf32> // ----- @@ -1666,8 +1671,13 @@ func @test_constant_dense_2d_value(%arg0: tensor<1xf32>) -> tensor<*xf32> { %0 = "onnx.Constant"() {value = dense<[[0.0, 0.0], [1.0, 1.1], [2.0, 2.1]]> : tensor<3x2xf32>} : () -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_constant_dense_2d_value - // CHECK: [[RES:%.+]] = "krnl.global"() {name = "constant_0", shape = [3, 2], value = dense<{{.*}}[0.000000e+00, 0.000000e+00], [1.000000e+00, 1.100000e+00], [2.000000e+00, 2.100000e+00]{{.*}}> : tensor<3x2xf32>} : () -> memref<3x2xf32> - // CHECK: return [[RES]] : memref<3x2xf32> + // CHECK: [[GLOBAL:%.+]] = "krnl.global"() {name = "constant_0", shape = [3, 2], value = dense<{{.*}}[0.000000e+00, 0.000000e+00], [1.000000e+00, 1.100000e+00], [2.000000e+00, 2.100000e+00]{{.*}}> : tensor<3x2xf32>} : () -> memref<3x2xf32> + // CHECK: [[ALLOC:%.+]] = alloc() : memref<3x2xf32> + // CHECK: [[CONST_4:%.+]] = constant 4 : i64 + // CHECK: [[CONST_6:%.+]] = constant 6 : i64 + // CHECK: [[SIZE:%.+]] = muli [[CONST_4]], [[CONST_6]] : i64 + // CHECK: "krnl.memcpy"([[ALLOC]], [[GLOBAL]], [[SIZE]]) : (memref<3x2xf32>, memref<3x2xf32>, i64) -> () + // CHECK: return [[ALLOC]] : memref<3x2xf32> } // -----