From fd65a66d7e0f348a5563f7a796c0969c61130743 Mon Sep 17 00:00:00 2001
From: Praveen G <73869424+praveen-g-ctt@users.noreply.github.com>
Date: Wed, 5 Feb 2025 11:56:05 +0530
Subject: [PATCH] [torch-mlir] Support lowering of aten constraint ops (#3943)

Add Torch dialect definitions, a Torch-to-Linalg lowering, and
decompositions for the following constraint ops:

1. aten::sym_constrain_range
2. aten::sym_constrain_range_for_size
3. aten::_assert_scalar
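
For reference, these ops typically arrive here from exported programs that
call item() on a tensor and then constrain or assert on the resulting
data-dependent scalar. A minimal illustrative snippet (it mirrors the new
e2e tests below; the module name is made up and the snippet is not part of
this patch):

    import torch

    class ConstrainExample(torch.nn.Module):
        def forward(self, x):
            a = x.item()
            # Bound the unbacked symbolic int so the compiler can reason
            # about its range, then assert a data-dependent condition.
            torch.ops.aten.sym_constrain_range(a, min=0, max=5)
            torch.ops.aten._assert_scalar(a > 3, "expected x.item() > 3")
            return a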
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 71 +++++++++++++++++
 .../TorchToLinalg/Uncategorized.cpp           | 66 ++++++++++++++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 78 +++++++++++++++++++
 projects/pt1/e2e_testing/xfail_sets.py        | 12 ++-
 .../build_tools/torch_ods_gen.py              |  5 ++
 .../torch_mlir_e2e_test/test_suite/basic.py   | 59 ++++++++++++++
 .../Conversion/TorchToLinalg/constraints.mlir | 30 +++++++
 test/Dialect/Torch/decompose-complex-ops.mlir | 50 ++++++++++++
 8 files changed, 370 insertions(+), 1 deletion(-)
 create mode 100644 test/Conversion/TorchToLinalg/constraints.mlir

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 2d71d0d8fe3d..c5a31a3d2fb2 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -17771,6 +17771,77 @@ def Torch_Aten_MakePerTensorQuantizedTensorOp : Torch_Op<"aten._make_per_tensor_
   }];
 }
 
+def Torch_AtenSymConstrainRangeOp : Torch_Op<"aten.sym_constrain_range", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::sym_constrain_range : (Scalar, int?, int?) -> ()`";
+  let arguments = (ins
+    AnyTorchScalarType:$size,
+    AnyTorchOptionalIntType:$min,
+    AnyTorchOptionalIntType:$max
+  );
+  let results = (outs
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenSymConstrainRangeOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 0);
+    }
+    void AtenSymConstrainRangeOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 0);
+    }
+  }];
+}
+
+def Torch_AtenSymConstrainRangeForSizeOp : Torch_Op<"aten.sym_constrain_range_for_size", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::sym_constrain_range_for_size : (Scalar, int?, int?) -> ()`";
+  let arguments = (ins
+    AnyTorchScalarType:$size,
+    AnyTorchOptionalIntType:$min,
+    AnyTorchOptionalIntType:$max
+  );
+  let results = (outs
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenSymConstrainRangeForSizeOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 0);
+    }
+    void AtenSymConstrainRangeForSizeOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 0);
+    }
+  }];
+}
+
+def Torch_Aten_AssertScalarOp : Torch_Op<"aten._assert_scalar", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::_assert_scalar : (Scalar, str) -> ()`";
+  let arguments = (ins
+    AnyTorchScalarType:$self,
+    Torch_StringType:$assert_msg
+  );
+  let results = (outs
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult Aten_AssertScalarOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 0);
+    }
+    void Aten_AssertScalarOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 0);
+    }
+  }];
+}
+
 def Torch_PrimLayoutOp : Torch_Op<"prim.layout", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index c83f49d7f62d..4ebdfbf94129 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -21,10 +21,12 @@
 #include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
 #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "llvm/ADT/APSInt.h"
 #include <numeric>
+#include <string>
 #include <type_traits>
 
 using namespace mlir;
@@ -3564,6 +3566,68 @@ class ConvertAtenPolarOp : public OpConversionPattern<AtenPolarOp> {
 };
 } // namespace
 
+namespace {
+class ConvertSymConstrainRangeOp
+    : public OpConversionPattern<AtenSymConstrainRangeOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenSymConstrainRangeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    auto loc = op.getLoc();
+    auto min = op.getMin();
+    auto max = op.getMax();
+
+    int64_t minValue = std::numeric_limits<int64_t>::min();
+    int64_t maxValue = std::numeric_limits<int64_t>::max();
+
+    Type operandType = getTypeConverter()->convertType(op.getSize().getType());
+
+    if (!isa<Torch::NoneType>(min.getType()))
+      if (!matchPattern(min, m_TorchConstantInt(&minValue)))
+        return rewriter.notifyMatchFailure(
+            op, "Expected min value to be constant integer");
+
+    if (!isa<Torch::NoneType>(max.getType()))
+      if (!matchPattern(max, m_TorchConstantInt(&maxValue)))
+        return rewriter.notifyMatchFailure(
+            op, "Expected max value to be constant integer");
+
+    if (maxValue < minValue) {
+      std::string errorMsg =
+          "Max must be greater than or equal to min, got min = " +
+          std::to_string(minValue) + ", max = " + std::to_string(maxValue);
+      return op.emitError(errorMsg);
+    }
+
+    min = getConstant(rewriter, loc, minValue, operandType);
+    max = getConstant(rewriter, loc, maxValue, operandType);
+
+    // Check min <= size <= max
+
+    // FIXME: Skip the below checks if constraint ops are already inserted as
+    // part of symbol expr evaluation
+    auto checkMin = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sle, min, adaptor.getSize());
+    auto checkMax = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sle, adaptor.getSize(), max);
+    auto compareVal = rewriter.create<arith::AndIOp>(loc, checkMin, checkMax);
+
+    std::string assertMessage = "Size constraint failed. Expected range: [" +
+                                std::to_string(minValue) + ", " +
+                                std::to_string(maxValue) + "]";
+    rewriter.create<cf::AssertOp>(loc, compareVal,
+                                  rewriter.getStringAttr(assertMessage));
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+} // namespace
+
 void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target) {
@@ -3626,4 +3690,6 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
   patterns.add<ConvertInterpolateOp>(typeConverter, context);
   target.addIllegalOp<AtenPolarOp>();
   patterns.add<ConvertAtenPolarOp>(typeConverter, context);
+  target.addIllegalOp<AtenSymConstrainRangeOp>();
+  patterns.add<ConvertSymConstrainRangeOp>(typeConverter, context);
 }
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 3303ec1ecc1b..1226ad2c03e2 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -11455,6 +11455,80 @@ class DecomposeAtenSpecialExpm1Op
 };
 } // namespace
 
+namespace {
+class DecomposeAtenConstrainRangeForSizeOp
+    : public OpRewritePattern<AtenSymConstrainRangeForSizeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenSymConstrainRangeForSizeOp op,
+                                PatternRewriter &rewriter) const override {
+
+    auto loc = op.getLoc();
+    auto min = op.getMin();
+    auto max = op.getMax();
+
+    int64_t minValue, maxValue;
+
+    if (isa<Torch::NoneType>(min.getType())) {
+      // Set min value to 0
+      min = rewriter.create<Torch::ConstantIntOp>(loc, 0);
+    } else {
+      // Check if min value is a constant
+      if (!matchPattern(min, m_TorchConstantInt(&minValue)))
+        return rewriter.notifyMatchFailure(
+            op, "Expected min value to be constant integer");
+    }
+
+    if (!isa<Torch::NoneType>(max.getType())) {
+      // Verify that max value is greater than 2
+      if (!matchPattern(max, m_TorchConstantInt(&maxValue)))
+        return rewriter.notifyMatchFailure(
+            op, "Expected max value to be constant integer");
+
+      if (maxValue <= 2) {
+        std::string errorMsg = "Max value to constrain_range_for_size must be "
+                               "greater than 2, got: " +
+                               std::to_string(maxValue);
+        return op.emitError(errorMsg);
+      }
+    }
+
+    rewriter.replaceOpWithNewOp<AtenSymConstrainRangeOp>(op, op.getSize(), min,
+                                                         max);
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class DecomposeAten_AssertScalarOp
+    : public OpRewritePattern<Aten_AssertScalarOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(Aten_AssertScalarOp op,
+                                PatternRewriter &rewriter) const override {
+
+    auto loc = op.getLoc();
+    auto assertCond = op.getSelf();
+
+    if (isa<Torch::IntType>(assertCond.getType()))
+      assertCond = rewriter.create<AtenBoolIntOp>(loc, assertCond);
+    else if (isa<Torch::FloatType>(assertCond.getType()))
+      assertCond = rewriter.create<AtenBoolFloatOp>(loc, assertCond);
+    assert(isa<Torch::BoolType>(assertCond.getType()) &&
+           "Unhandled type encountered in aten._assert_scalar op");
+
+    std::string assertMessage;
+    if (!matchPattern(op.getAssertMsg(), m_TorchConstantStr(assertMessage)))
+      return rewriter.notifyMatchFailure(
+          op, "Assert message must be a constant string");
+
+    rewriter.replaceOpWithNewOp<RuntimeAssertOp>(op, assertCond,
+                                                 assertMessage);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -11753,6 +11827,10 @@ class DecomposeComplexOpsPass
     // Torchvision ops
     addPatternIfTargetOpIsIllegal<DecomposeTorchvisionNmsOp>(patterns);
 
+    addPatternIfTargetOpIsIllegal<DecomposeAtenConstrainRangeForSizeOp>(
+        patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAten_AssertScalarOp>(patterns);
+
     GreedyRewriteConfig config;
     config.useTopDownTraversal = true;
     config.maxIterations = GreedyRewriteConfig::kNoLimit;
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 4df3d186f8ea..e433fabe2712 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -35,6 +35,10 @@
     "Aten_TrilinearModuleZerodDimBug_basic",
     # missing lowering from aten.pow.Tensor_Tensor for integer result
     "PowIntIntModule_basic",
+    # Unknown builtin op: aten::_check_is_size in TorchScript
+    "AtenSymConstrainRange_basic",
+    "AtenSymConstrainRangeForSize_basic",
+    "Aten_AssertScalar_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.5.0.dev"):
@@ -623,7 +627,6 @@
     "AtenMmQMixedSigni8_basic",
     "AtenMmQint8_basic",
     "AtenMmQuint8_basic",
-    "AtenNonzero1DDynamicModule_basic",
     "AtenRealView128Module_basic",
     "AtenRealView64Module_basic",
     "AtenTopKModule_basic",
@@ -941,6 +944,9 @@
     "UniformModule_basic",
     "UniformStaticShapeModule_basic",
     "ScaledDotProductAttentionGQAModule_basic",
+    "AtenSymConstrainRange_basic",
+    "AtenSymConstrainRangeForSize_basic",
+    "Aten_AssertScalar_basic",
 }
 
 FX_IMPORTER_STABLEHLO_CRASHING_SET = {
@@ -964,6 +970,7 @@
     "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
     "CrossEntropyLossModule_basic",
     "CrossEntropyLossNoReductionModule_basic",
+    "AtenNonzero1DDynamicModule_basic",  # error: Mismatched ranks of types 2 vs 1
 }
 
 STABLEHLO_PASS_SET = {
@@ -3254,6 +3261,9 @@
     "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
     "Aten_TrilinearModuleZerodDimBug_basic",
     "ScaledDotProductAttentionGQAModule_basic",
+    "AtenSymConstrainRange_basic",
+    "AtenSymConstrainRangeForSize_basic",
+    "Aten_AssertScalar_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.3.0.dev"):
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 4d7f8d52268c..350fea711bbf 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -1232,6 +1232,11 @@ def emit_with_mutating_variants(key, **kwargs):
     )
     emit("aten::_make_per_tensor_quantized_tensor : (Tensor, float, int) -> (Tensor)")
 
+    # Constraint ops
+    emit("aten::sym_constrain_range : (Scalar, int?, int?) -> ()")
+    emit("aten::sym_constrain_range_for_size : (Scalar, int?, int?) -> ()")
+    emit("aten::_assert_scalar : (Scalar, str) -> ()")
+
     # ==========================================================================
     # `prim::` namespace.
     # ==========================================================================
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index fe8a31186807..4ba497452a76 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -6480,3 +6480,62 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: AtenNonzero1DDynamicModule())
 def AtenNonzero1DDynamicModule_basic(module, tu: TestUtils):
     module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.bool))
+
+
+# ==============================================================================
+
+
+class AtenSymConstrainRange(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1], torch.int, True)])
+    def forward(self, x):
+        a = x.item()
+        torch.ops.aten.sym_constrain_range(a, max=5)
+        return a
+
+
+@register_test_case(module_factory=lambda: AtenSymConstrainRange())
+def AtenSymConstrainRange_basic(module, tu: TestUtils):
+    module.forward(torch.tensor(4))
+
+
+# ==============================================================================
+
+
+class AtenSymConstrainRangeForSize(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1], torch.int, True)])
+    def forward(self, x):
+        a = x.item()
+        torch.ops.aten.sym_constrain_range_for_size(a, min=0, max=10)
+        return a
+
+
+@register_test_case(module_factory=lambda: AtenSymConstrainRangeForSize())
+def AtenSymConstrainRangeForSize_basic(module, tu: TestUtils):
+    module.forward(torch.tensor(4))
+
+
+# ==============================================================================
+class Aten_AssertScalar(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1], torch.int, True)])
+    def forward(self, x):
+        a = x.item()
+        assert_msg = "Assertion failed for condition x.item() > 3"
+        torch.ops.aten._assert_scalar(a > 3, assert_msg)
+        return a
+
+
+@register_test_case(module_factory=lambda: Aten_AssertScalar())
+def Aten_AssertScalar_basic(module, tu: TestUtils):
+    module.forward(torch.tensor(4))
diff --git a/test/Conversion/TorchToLinalg/constraints.mlir b/test/Conversion/TorchToLinalg/constraints.mlir
new file mode 100644
index 000000000000..19075d72103a
--- /dev/null
+++ b/test/Conversion/TorchToLinalg/constraints.mlir
@@ -0,0 +1,30 @@
+// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL:   func.func @torch.aten.sym_constrain_range(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.int) -> !torch.int {
+// CHECK:           %[[VAL_1:.*]] = torch_c.to_i64 %[[VAL_0]]
+// CHECK:           %[[VAL_2:.*]] = torch.constant.int 7
+// CHECK:           %[[VAL_3:.*]] = torch.constant.int 0
+// CHECK:           %[[VAL_4:.*]] = torch.constant.none
+// CHECK:           %[[VAL_5:.*]] = arith.constant 0 : i64
+// CHECK:           %[[VAL_6:.*]] = arith.constant 9223372036854775807 : i64
+// CHECK:           %[[VAL_7:.*]] = arith.cmpi sle, %[[VAL_5]], %[[VAL_1]] : i64
+// CHECK:           %[[VAL_8:.*]] = arith.cmpi sle, %[[VAL_1]], %[[VAL_6]] : i64
+// CHECK:           %[[VAL_9:.*]] = arith.andi %[[VAL_7]], %[[VAL_8]] : i1
+// CHECK:           cf.assert %[[VAL_9]], "Size constraint failed. Expected range: [0, 9223372036854775807]"
+// CHECK:           %[[VAL_10:.*]] = arith.constant 0 : i64
+// CHECK:           %[[VAL_11:.*]] = arith.constant 7 : i64
+// CHECK:           %[[VAL_12:.*]] = arith.cmpi sle, %[[VAL_10]], %[[VAL_1]] : i64
+// CHECK:           %[[VAL_13:.*]] = arith.cmpi sle, %[[VAL_1]], %[[VAL_11]] : i64
+// CHECK:           %[[VAL_14:.*]] = arith.andi %[[VAL_12]], %[[VAL_13]] : i1
+// CHECK:           cf.assert %[[VAL_14]], "Size constraint failed. Expected range: [0, 7]"
+// CHECK:           return %[[VAL_0]] : !torch.int
+// CHECK:         }
+func.func @torch.aten.sym_constrain_range(%arg0: !torch.int) -> !torch.int {
+  %int7 = torch.constant.int 7
+  %int0 = torch.constant.int 0
+  %none = torch.constant.none
+  torch.aten.sym_constrain_range %arg0, %int0, %none : !torch.int, !torch.int, !torch.none
+  torch.aten.sym_constrain_range %arg0, %int0, %int7 : !torch.int, !torch.int, !torch.int
+  return %arg0 : !torch.int
+}
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index 384502ecd2af..4c99f4949a38 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -228,3 +228,53 @@ func.func @torch.aten.fft_rfft$2d_first_dim(%arg0: !torch.vtensor<[36,23],f32>)
   %out = torch.aten.fft_rfft %arg0, %none, %int0, %none : !torch.vtensor<[36,23],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[19,23],complex<f32>>
   return %out : !torch.vtensor<[19,23],complex<f32>>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.sym_constrain_range_for_size(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.int) -> !torch.int {
+// CHECK:           %[[VAL_1:.*]] = torch.constant.int 7
+// CHECK:           %[[VAL_2:.*]] = torch.constant.int 0
+// CHECK:           %[[VAL_3:.*]] = torch.constant.none
+// CHECK:           torch.aten.sym_constrain_range %[[VAL_0]], %[[VAL_2]], %[[VAL_3]] : !torch.int, !torch.int, !torch.none
+// CHECK:           torch.aten.sym_constrain_range %[[VAL_0]], %[[VAL_2]], %[[VAL_1]] : !torch.int, !torch.int, !torch.int
+// CHECK:           return %[[VAL_0]] : !torch.int
+// CHECK:         }
+func.func @torch.aten.sym_constrain_range_for_size(%arg0: !torch.int) -> !torch.int {
+  %int7 = torch.constant.int 7
+  %int0 = torch.constant.int 0
+  %none = torch.constant.none
+  torch.aten.sym_constrain_range_for_size %arg0, %none, %none : !torch.int, !torch.none, !torch.none
+  torch.aten.sym_constrain_range_for_size %arg0, %int0, %int7 : !torch.int, !torch.int, !torch.int
+  return %arg0 : !torch.int
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten._assert_scalar(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.int) -> !torch.int {
+// CHECK:           %[[VAL_1:.*]] = torch.constant.int 2
+// CHECK:           %[[VAL_2:.*]] = torch.constant.int 3
+// CHECK:           %[[VAL_3:.*]] = torch.aten.ge.int %[[VAL_0]], %[[VAL_2]] : !torch.int, !torch.int -> !torch.bool
+// CHECK:           %[[VAL_4:.*]] = torch.aten.Int.bool %[[VAL_3]] : !torch.bool -> !torch.int
+// CHECK:           %[[VAL_5:.*]] = torch.aten.Bool.int %[[VAL_4]] : !torch.int -> !torch.bool
+// CHECK:           torch.runtime.assert %[[VAL_5]], "Runtime assertion failed for expression u0 >= 3 on node 'ge_1'"
+// CHECK:           %[[VAL_6:.*]] = torch.aten.gt.int %[[VAL_0]], %[[VAL_1]] : !torch.int, !torch.int -> !torch.bool
+// CHECK:           %[[VAL_7:.*]] = torch.aten.Int.bool %[[VAL_6]] : !torch.bool -> !torch.int
+// CHECK:           %[[VAL_8:.*]] = torch.aten.Bool.int %[[VAL_7]] : !torch.int -> !torch.bool
+// CHECK:           torch.runtime.assert %[[VAL_8]], "Runtime assertion failed for expression 2 < u0 on node 'gt_1'"
+// CHECK:           return %[[VAL_0]] : !torch.int
+// CHECK:         }
+func.func @torch.aten._assert_scalar(%arg0: !torch.int) -> !torch.int {
+  %str = torch.constant.str "Runtime assertion failed for expression 2 < u0 on node 'gt_1'"
torch.constant.str "Runtime assertion failed for expression 2 < u0 on node 'gt_1'" + %int2 = torch.constant.int 2 + %str_0 = torch.constant.str "Runtime assertion failed for expression u0 >= 3 on node 'ge_1'" + %int3 = torch.constant.int 3 + %0 = torch.aten.ge.int %arg0, %int3 : !torch.int, !torch.int -> !torch.bool + %1 = torch.aten.Int.bool %0 : !torch.bool -> !torch.int + torch.aten._assert_scalar %1, %str_0 : !torch.int, !torch.str + %2 = torch.aten.gt.int %arg0, %int2 : !torch.int, !torch.int -> !torch.bool + %3 = torch.aten.Int.bool %2 : !torch.bool -> !torch.int + torch.aten._assert_scalar %3, %str : !torch.int, !torch.str + return %arg0 : !torch.int +}
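
Note: once torch-mlir-opt is built, the new lit tests can be run
individually with llvm-lit (the exact lit binary and test-tree paths may
vary with your build layout):

    llvm-lit -v test/Conversion/TorchToLinalg/constraints.mlir
    llvm-lit -v test/Dialect/Torch/decompose-complex-ops.mlir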