From d81747eadbbdc0f97b64dd2964aedb6497de4435 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Tue, 27 Feb 2024 11:02:05 +0530 Subject: [PATCH] [MLIR][TORCH] Extend support for OnnxToLinalg lowering for Dropout and Div op (#2938) Fixes https://github.com/nod-ai/SHARK-Turbine/issues/451, https://github.com/nod-ai/SHARK-Turbine/issues/452 --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 2 + .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 38 ++++++++++++--- .../TorchToLinalg/TensorScalarInterop.cpp | 20 +++++--- .../TorchToLinalg/Uncategorized.cpp | 14 ++++-- lib/Dialect/Torch/IR/TorchOps.cpp | 32 +++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 9 +++- .../build_tools/torch_ods_gen.py | 4 +- .../torch_mlir_e2e_test/test_suite/basic.py | 44 ++++++++++++++++++ .../test_suite/elementwise.py | 46 +++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 9 ++++ test/Dialect/Torch/canonicalize.mlir | 46 +++++++++++++++++++ 11 files changed, 243 insertions(+), 21 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index dc1203de9471..57b15ed18f4e 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -11231,6 +11231,7 @@ def Torch_AtenIntImplicitOp : Torch_Op<"aten.IntImplicit", [ printDefaultTorchOp(printer, *this, 1, 1); } }]; + let hasCanonicalizer = 1; } def Torch_AtenFloatImplicitOp : Torch_Op<"aten.FloatImplicit", [ @@ -11254,6 +11255,7 @@ def Torch_AtenFloatImplicitOp : Torch_Op<"aten.FloatImplicit", [ printDefaultTorchOp(printer, *this, 1, 1); } }]; + let hasCanonicalizer = 1; } def Torch_AtenTensorFloatOp : Torch_Op<"aten.tensor.float", [ diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 99a3985a2993..1c356db890db 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -1339,12 +1339,38 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( Value ratio, trainingMode; if (numOperands == 3) { ratio = rewriter.create(loc, operands[1]); - Value trainingModeScalar = - rewriter.create(loc, operands[2]); - Value cstOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - trainingMode = rewriter.create( - loc, trainingModeScalar, cstOne); + Value trainVal = operands[2]; + auto trainTensorType = + trainVal.getType().dyn_cast(); + if (!trainTensorType) + return rewriter.notifyMatchFailure(binder.op, + "train tensor must have a type"); + + Type inputDtype = trainTensorType.getOptionalDtype(); + if (!inputDtype || !inputDtype.isInteger(1)) + return rewriter.notifyMatchFailure( + binder.op, + "train tensor must have an integer dtype of width 1"); + + std::optional inputRank = Torch::getTensorRank(trainVal); + if (!inputRank || *inputRank != 0) + return rewriter.notifyMatchFailure(binder.op, + "train tensor must have rank 0"); + + if (auto valueTensorLiteralOp = + trainVal.getDefiningOp()) { + auto val = valueTensorLiteralOp.getValue() + .cast() + .getSplatValue(); + trainingMode = rewriter.create(loc, val); + } else { + Value trainingModeScalar = + rewriter.create(loc, operands[2]); + Value cstOne = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + trainingMode = rewriter.create( + loc, trainingModeScalar, cstOne); + } } else if (numOperands == 2) { ratio = rewriter.create(loc, operands[1]); trainingMode = rewriter.create(loc, false); diff --git a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp index a1e8e5fb72d9..58e6daa9bca8 100644 --- a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp +++ b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp @@ -191,12 +191,14 @@ class ConvertPrimNumToTensorScalarOp } // namespace namespace { -class ConvertAtenScalarImplicitOp - : public OpConversionPattern { +// Converts a tensor with one element to a scalar value. +template +class ConvertAtenImplicitLikeOp : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(AtenScalarImplicitOp op, OpAdaptor adaptor, + matchAndRewrite(OpTy op, + typename OpConversionPattern::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, adaptor.getA()); return success(); @@ -224,6 +226,12 @@ void mlir::torch::torch_to_linalg:: patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - target.addIllegalOp(); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + target.addIllegalOp(); } diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index ed6883000cf9..87163fc95c4a 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -725,13 +725,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type dtype = converter->convertType(div.getType()) .cast() .getElementType(); - if (!dtype.isa()) { - div.emitError("unimplemented: non-floating point dtype"); - return nullptr; - } Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); - return b.create(loc, lhs, rhs); + if (dtype.isa()) + return b.create(loc, lhs, rhs); + else if (dtype.isa()) { + if (dtype.isUnsignedInteger()) + return b.create(loc, lhs, rhs); + return b.create(loc, lhs, rhs); + } + div.emitError("unimplemented: non-floating point and non-integer dtype"); + return nullptr; } if (auto divTensorMode = dyn_cast(op)) { AtenDivTensorModeOp::Adaptor adaptor(operands); diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index ef3098eb1c12..2f0884b1344e 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1568,6 +1568,38 @@ void AtenScalarImplicitOp::getCanonicalizationPatterns( }); } +//===----------------------------------------------------------------------===// +// AtenFloatImplicitOp +//===----------------------------------------------------------------------===// +void AtenFloatImplicitOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add(+[](AtenFloatImplicitOp op, PatternRewriter &rewriter) { + Location loc = op.getLoc(); + Value a = op.getA(); + Value scalarValue = getScalarFloatValue(a, loc, rewriter); + if (!scalarValue) + return failure(); + rewriter.replaceOp(op, scalarValue); + return success(); + }); +} + +//===----------------------------------------------------------------------===// +// AtenIntImplicitOp +//===----------------------------------------------------------------------===// +void AtenIntImplicitOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](AtenIntImplicitOp op, PatternRewriter &rewriter) { + Location loc = op.getLoc(); + Value a = op.getA(); + Value scalarValue = getScalarIntValue(a, loc, rewriter); + if (!scalarValue) + return failure(); + rewriter.replaceOp(op, scalarValue); + return success(); + }); +} + //===----------------------------------------------------------------------===// // AtenSizeOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 70f26fe421e0..82c1a0759e2e 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -335,6 +335,9 @@ # Dynamo not supporting conv_tbc "ConvTbcModule_basic", + + "FloatImplicitModule_basic", + "IntImplicitModule_basic", } TORCHDYNAMO_CRASHING_SET = { @@ -989,6 +992,8 @@ "ElementwiseCloneContiguousModule_basic", "ElementwiseCloneModule_basic", "ElementwiseDivScalarModule_basic", + "ElementwiseDivTensorIntegerModule_basic", + "ElementwiseDivTensorUnsignedIntegerModule_basic", "ElementwiseEluModule_basic", "ElementwiseEluNonDefaultModule_basic", "ElementwiseEqBoolScalarModule_basic", @@ -2146,8 +2151,6 @@ "ElementwiseSigmoidIntModule_basic", # Failure - unknown - "ChunkListUnpackUneven_Module_basic", - "ChunkListUnpack_Module_basic", "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "CopyWithDifferentDTypesAndSizesModule_basic", "CopyWithDifferentDTypesModule_basic", @@ -2168,6 +2171,8 @@ "ReduceMinAlongDimUnsignedInt_basic", "TensorsStackNegativeDimModule_basic", "TensorsStackPromoteDTypeModule_basic", + "FloatImplicitModule_basic", + "IntImplicitModule_basic", } ONNX_CRASHING_SET = { } diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index cc41a99be228..c81f543b5dc9 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -669,8 +669,8 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)") emit_with_mutating_variants("aten::scatter_add : (Tensor, int, Tensor, Tensor) -> (Tensor)") emit_with_mutating_variants("aten::scatter_reduce.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)") - emit("aten::IntImplicit : (Tensor) -> (int)") - emit("aten::FloatImplicit : (Tensor) -> (float)") + emit("aten::IntImplicit : (Tensor) -> (int)", has_canonicalizer=True) + emit("aten::FloatImplicit : (Tensor) -> (float)", has_canonicalizer=True) emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)") emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True) emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 7e707893911a..c5ef92d41637 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -3719,6 +3719,50 @@ def ScalarImplicitIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(low=-100, high=100)) +# ============================================================================== + + +class FloatImplicitModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([], torch.float64, True), + ]) + def forward(self, x): + return float(torch.ops.aten.FloatImplicit(x)) + + +@register_test_case(module_factory=lambda: FloatImplicitModule()) +def FloatImplicitModule_basic(module, tu: TestUtils): + module.forward(tu.rand().double()) + + +# ============================================================================== + + +class IntImplicitModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([], torch.int64, True), + ]) + def forward(self, x): + return float(torch.ops.aten.IntImplicit(x)) + + +@register_test_case(module_factory=lambda: IntImplicitModule()) +def IntImplicitModule_basic(module, tu: TestUtils): + module.forward(tu.randint()) + + # ============================================================================== class PowIntFloat(torch.nn.Module): diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index ad4abd9f1752..611effdb2338 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -2595,6 +2595,52 @@ def ElementwiseDivTensorFloatModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseDivTensorIntegerModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ([-1, -1], torch.int32, True), + ]) + def forward(self, a, b): + return torch.div(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseDivTensorIntegerModule()) +def ElementwiseDivTensorIntegerModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-10, high=10), tu.randint(3, 4, low=-10, high=10).type(torch.int32)) + + +# ============================================================================== + + +class ElementwiseDivTensorUnsignedIntegerModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.uint8, True), + ([-1, -1], torch.uint8, True), + ]) + def forward(self, a, b): + return torch.div(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseDivTensorUnsignedIntegerModule()) +def ElementwiseDivTensorUnsignedIntegerModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=0, high=10).to(torch.uint8), tu.randint(3, 4, low=0, high=10).type(torch.uint8)) + + +# ============================================================================== + + class ElementwiseDivRoundingModeTruncModule(torch.nn.Module): def __init__(self): diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 525583b7660e..7dc262228f1a 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -706,6 +706,15 @@ func.func @test_div(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3 // ----- +// CHECK-LABEL: @test_div_int32 +func.func @test_div_int32(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtensor<[3,4,5],si32>) -> !torch.vtensor<[3,4,5],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32> -> !torch.vtensor<[3,4,5],si32> + %0 = torch.operator "onnx.Div"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32>) -> !torch.vtensor<[3,4,5],si32> + return %0 : !torch.vtensor<[3,4,5],si32> +} + +// ----- + // CHECK-LABEL: @test_div_uint8 func.func @test_div_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],ui8>, !torch.vtensor<[3,4,5],ui8> -> !torch.vtensor<[3,4,5],ui8> diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 4df52cfb174b..85b95eb1cdba 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -2145,6 +2145,52 @@ func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d() -> !torch.number return %1 : !torch.number } +// ----- + +// CHECK-LABEL: func.func @torch.aten.FloatImplicit$canonicalize_numtotensor_0d() -> !torch.float { +// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00 +// CHECK: return %[[FLOAT1]] : !torch.float +func.func @torch.aten.FloatImplicit$canonicalize_numtotensor_0d() -> !torch.float { + %float1 = torch.constant.float 1.0 + %0 = torch.prim.NumToTensor.Scalar %float1 : !torch.float -> !torch.vtensor<[],f64> + %1 = torch.aten.FloatImplicit %0 : !torch.vtensor<[],f64> -> !torch.float + return %1 : !torch.float +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.FloatImplicit$canonicalize_literal_0d() -> !torch.float { +// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00 +// CHECK: return %[[FLOAT1]] : !torch.float +func.func @torch.aten.FloatImplicit$canonicalize_literal_0d() -> !torch.float { + %0 = torch.vtensor.literal(dense<1.0> : tensor) : !torch.vtensor<[],f64> + %1 = torch.aten.FloatImplicit %0 : !torch.vtensor<[],f64> -> !torch.float + return %1 : !torch.float +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.IntImplicit$canonicalize_numtotensor_0d() -> !torch.int { +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: return %[[INT1]] : !torch.int +func.func @torch.aten.IntImplicit$canonicalize_numtotensor_0d() -> !torch.int { + %int1 = torch.constant.int 1 + %0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64> + %1 = torch.aten.IntImplicit %0 : !torch.vtensor<[],si64> -> !torch.int + return %1 : !torch.int +} + +// CHECK-LABEL: func.func @torch.aten.IntImplicit$canonicalize_literal_0d() -> !torch.int { +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: return %[[INT1]] : !torch.int +func.func @torch.aten.IntImplicit$canonicalize_literal_0d() -> !torch.int { + %0 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %1 = torch.aten.IntImplicit %0 : !torch.vtensor<[],si64> -> !torch.int + return %1 : !torch.int +} + +// ----- + // CHECK-LABEL: func.func @torch.prims.view_of$fold( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,4,2],f32> { // CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[3,4,2],f32>