From 55dc8deb9221c9ec0fe2a991542f2f788c62a3e1 Mon Sep 17 00:00:00 2001
From: Andreas Falkenberg <149819731+afalkenberg1@users.noreply.github.com>
Date: Fri, 23 Feb 2024 09:14:38 -0800
Subject: [PATCH] [torch] GridSample TorchToLinalg lowering (#2883)

Lowers `torch.aten.grid_sampler` to the equivalent `linalg` representation.
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |  27 +++
 .../TorchToLinalg/Uncategorized.cpp           | 168 ++++++++++++++++++
 .../Transforms/AbstractInterpLibrary.cpp      |  15 ++
 .../build_tools/abstract_interp_lib_gen.py    |   8 +
 .../build_tools/torch_ods_gen.py              |   1 +
 .../Conversion/TorchToLinalg/gridsampler.mlir |  60 +++++++
 test/Conversion/TorchToLinalg/pooling.mlir    |   1 -
 7 files changed, 279 insertions(+), 1 deletion(-)
 create mode 100644 test/Conversion/TorchToLinalg/gridsampler.mlir

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index c5fec66913b0..cc8be7c6910b 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -12353,6 +12353,33 @@ def Torch_AtenScaledDotProductAttentionOp : Torch_Op<"aten.scaled_dot_product_at
   }];
 }
 
+def Torch_AtenGridSamplerOp : Torch_Op<"aten.grid_sampler", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::grid_sampler : (Tensor, Tensor, int, int, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$grid,
+    Torch_IntType:$interpolation_mode,
+    Torch_IntType:$padding_mode,
+    Torch_BoolType:$align_corners
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenGridSamplerOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 5, 1);
+    }
+    void AtenGridSamplerOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 5, 1);
+    }
+  }];
+}
+
 def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 08d69ca718b9..ed6883000cf9 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -2360,6 +2360,172 @@ class ConvertCastEquivalentOp : public OpConversionPattern<T> {
 };
 } // namespace
 
+namespace {
+class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenGridSamplerOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    Type int64type = rewriter.getI64Type();
+    Type floatType = rewriter.getF32Type();
+    Value zeroIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value oneIndex = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    Value twoIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2);
+    Value zeroFloat = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getFloatAttr(floatType, 0.0));
+    Value oneFloat = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getFloatAttr(floatType, 1.0));
+    Value twoFloat = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getFloatAttr(floatType, 2.0));
+    Value input = adaptor.getInput();
+    auto inputType = input.getType().cast<RankedTensorType>();
+    auto inputShape = inputType.getShape();
+    Value innerDim0a = rewriter.create<tensor::DimOp>(loc, input, 2);
+    Value innerDim1a = rewriter.create<tensor::DimOp>(loc, input, 3);
+    Value innerDim0b =
+        rewriter.create<arith::SubIOp>(loc, innerDim0a, oneIndex);
+    Value innerDim1b =
+        rewriter.create<arith::SubIOp>(loc, innerDim1a, oneIndex);
+    Value innerDim0c =
+        rewriter.create<arith::IndexCastOp>(loc, int64type, innerDim0b);
+    Value innerDim1c =
+        rewriter.create<arith::IndexCastOp>(loc, int64type, innerDim1b);
+    Value innerDim0d =
+        rewriter.create<arith::SIToFPOp>(loc, floatType, innerDim0c);
+    Value innerDim1d =
+        rewriter.create<arith::SIToFPOp>(loc, floatType, innerDim1c);
+    Value innerDim0e =
+        rewriter.create<arith::DivFOp>(loc, innerDim0d, twoFloat);
+    Value innerDim1e =
+        rewriter.create<arith::DivFOp>(loc, innerDim1d, twoFloat);
+    Value grid = adaptor.getGrid();
+    auto gridType = grid.getType().cast<RankedTensorType>();
+    auto gridShape = gridType.getShape();
+    auto gridRank = gridType.getRank();
+    SmallVector<Value> extractGridOffsets0(gridRank, zeroIndex);
+    SmallVector<Value> extractGridShape = getTensorSizes(rewriter, loc, grid);
+    SmallVector<Value> extractGridStride(gridRank, oneIndex);
+    int64_t lastGridDim = gridRank - 1;
+    extractGridShape[lastGridDim] = oneIndex;
+    extractGridStride[lastGridDim] = twoIndex;
+    SmallVector<Value> extractGridOffsets1(gridRank, zeroIndex);
+    extractGridOffsets1[lastGridDim] = oneIndex;
+    SmallVector<int64_t> gridShapeExtracted(gridShape);
+    gridShapeExtracted.back() = 1;
+    SmallVector<int64_t> gridShapeCollapsed{gridShape[0], gridShape[1],
+                                            gridShape[2]};
+    auto grid0 = rewriter.create<tensor::ExtractSliceOp>(
+        loc, grid, extractGridOffsets0, extractGridShape, extractGridStride);
+    auto grid1 = rewriter.create<tensor::ExtractSliceOp>(
+        loc, grid, extractGridOffsets1, extractGridShape, extractGridStride);
+    SmallVector<ReassociationIndices> associations{ReassociationIndices{0},
+                                                   ReassociationIndices{1},
+                                                   ReassociationIndices{2, 3}};
+    auto gridCollapsed0 =
+        rewriter.create<tensor::CollapseShapeOp>(loc, grid0, associations);
+    auto gridCollapsed1 =
+        rewriter.create<tensor::CollapseShapeOp>(loc, grid1, associations);
+    AffineMap gridMap = AffineMap::get(4, 0,
+                                       {rewriter.getAffineDimExpr(0),
+                                        rewriter.getAffineDimExpr(2),
+                                        rewriter.getAffineDimExpr(3)},
+                                       op->getContext());
+    SmallVector<AffineMap> gridMaps{gridMap, gridMap,
+                                    rewriter.getMultiDimIdentityMap(gridRank)};
+    SmallVector<utils::IteratorType> gridIterators(
+        gridRank, utils::IteratorType::parallel);
+    SmallVector<int64_t> resultShape{inputShape[0], inputShape[1],
+                                     gridShape[1], gridShape[2]};
+    auto lambdaExtract = [](OpBuilder &b, Location loc, Value input, Value idxA,
+                            Value idxB, Value idxC, Value idxD) -> Value {
+      SmallVector<Value> index{idxA, idxB, idxC, idxD};
+      Value result = b.create<tensor::ExtractOp>(loc, input, index);
+      return result;
+    };
+    auto lambdaInter = [&](OpBuilder &b, Location loc, Value x, Value y,
+                           Value d) -> Value {
+      Value dm = b.create<arith::SubFOp>(loc, oneFloat, d);
+      Value ra = b.create<arith::MulFOp>(loc, x, dm);
+      Value rb = b.create<arith::MulFOp>(loc, y, d);
+      Value res = b.create<arith::AddFOp>(loc, ra, rb);
+      return res;
+    };
+    auto resultType = getTypeConverter()
+                          ->convertType(op.getResult().getType())
+                          .cast<RankedTensorType>();
+    llvm::SmallVector<Value> resultSize{
+        rewriter.create<tensor::DimOp>(loc, input, 0),
+        rewriter.create<tensor::DimOp>(loc, input, 1),
+        rewriter.create<tensor::DimOp>(loc, grid, 1),
+        rewriter.create<tensor::DimOp>(loc, grid, 2)};
+    Value resultFinal =
+        rewriter.create<tensor::EmptyOp>(loc, resultType, resultSize);
+    auto sGrid = rewriter.create<linalg::GenericOp>(
+        loc, TypeRange{resultType}, ValueRange{gridCollapsed0, gridCollapsed1},
+        ValueRange(resultFinal), gridMaps, gridIterators,
+        [&](OpBuilder &b, Location loc, ValueRange args) {
+          Value gr0 = args[0];
+          Value gr1 = args[1];
+          Value gplus0 = b.create<arith::AddFOp>(loc, gr0, oneFloat);
+          Value gplus1 = b.create<arith::AddFOp>(loc, gr1, oneFloat);
+          Value result0 = b.create<arith::MulFOp>(loc, gplus0, innerDim0e);
+          Value result1 = b.create<arith::MulFOp>(loc, gplus1, innerDim1e);
+          Value lower0 = b.create<arith::FPToSIOp>(loc, int64type, result0);
+          Value lower1 = b.create<arith::FPToSIOp>(loc, int64type, result1);
+          Value oneInt =
+              b.create<arith::ConstantOp>(loc, b.getIntegerAttr(int64type, 1));
+          Value upper0 =
+              b.create<arith::AddIOp>(loc, int64type, lower0, oneInt);
+          Value upper1 =
+              b.create<arith::AddIOp>(loc, int64type, lower1, oneInt);
+          Value notValid0 = rewriter.create<arith::CmpIOp>(
+              loc, arith::CmpIPredicate::sgt, upper0, innerDim0c);
+          Value notValid1 = rewriter.create<arith::CmpIOp>(
+              loc, arith::CmpIPredicate::sgt, upper1, innerDim1c);
+          Value upperValid0 =
+              b.create<arith::SelectOp>(loc, notValid0, lower0, upper0);
+          Value upperValid1 =
+              b.create<arith::SelectOp>(loc, notValid1, lower1, upper1);
+          Value lw0 =
+              b.create<arith::IndexCastOp>(loc, b.getIndexType(), lower0);
+          Value lw1 =
+              b.create<arith::IndexCastOp>(loc, b.getIndexType(), lower1);
+          Value up0 =
+              b.create<arith::IndexCastOp>(loc, b.getIndexType(), upperValid0);
+          Value up1 =
+              b.create<arith::IndexCastOp>(loc, b.getIndexType(), upperValid1);
+          Value N = b.create<linalg::IndexOp>(loc, 0);
+          Value C = b.create<linalg::IndexOp>(loc, 1);
+          Value result00 = lambdaExtract(b, loc, input, N, C, lw0, lw1);
+          Value result01 = lambdaExtract(b, loc, input, N, C, lw0, up1);
+          Value result01a =
+              b.create<arith::SelectOp>(loc, notValid1, zeroFloat, result01);
+          Value result10 = lambdaExtract(b, loc, input, N, C, up0, lw1);
+          Value result10a =
+              b.create<arith::SelectOp>(loc, notValid0, zeroFloat, result10);
+          Value result11 = lambdaExtract(b, loc, input, N, C, up0, up1);
+          Value result11a =
+              b.create<arith::SelectOp>(loc, notValid0, zeroFloat, result11);
+          Value result11b =
+              b.create<arith::SelectOp>(loc, notValid1, zeroFloat, result11a);
+          Value lw0a = b.create<arith::SIToFPOp>(loc, floatType, lower0);
+          Value lw1a = b.create<arith::SIToFPOp>(loc, floatType, lower1);
+          Value d0 = b.create<arith::SubFOp>(loc, result0, lw0a);
+          Value d1 = b.create<arith::SubFOp>(loc, result1, lw1a);
+          Value resultScaled0 = lambdaInter(b, loc, result00, result01a, d1);
+          Value resultScaled1 = lambdaInter(b, loc, result10a, result11b, d1);
+          Value resultScaled =
+              lambdaInter(b, loc, resultScaled0, resultScaled1, d0);
+          b.create<linalg::YieldOp>(loc, resultScaled);
+        });
+    rewriter.replaceOp(op, sGrid.getResults());
+    return success();
+  }
+};
+} // namespace
+
 void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target) {
@@ -2412,4 +2578,6 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       typeConverter, context);
   target.addIllegalOp<AtenIntReprOp>();
   patterns.add<ConvertCastEquivalentOp<AtenIntReprOp>>(typeConverter, context);
+  target.addIllegalOp<AtenGridSamplerOp>();
+  patterns.add<ConvertAtenGridSamplerOp>(typeConverter, context);
 }
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 39813da66e85..bfc2fc6a1d0c 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6597,6 +6597,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.grid_sampler\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.bool) -> !torch.list<int> {\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int1 = torch.constant.int 1\n"
+"    %int2 = torch.constant.int 2\n"
+"    %0 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %2 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %3 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
+"    return %4 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.prims.collapse\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list<int> {\n"
 "    %true = torch.constant.bool true\n"
 "    %str = torch.constant.str \"AssertionError: start must be less than or equal to end\"\n"
@@ -9795,6 +9806,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.grid_sampler\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.bool) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.reflection_pad1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: padding size expected to be 2\"\n"
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 1a87bbb6bee1..403d124ad927 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -287,6 +287,10 @@ def aten〇_make_per_tensor_quantized_tensor〡shape(self: List[int], scale: flo
 def prims〇convert_element_type〡shape(a: List[int], dtype: int) -> List[int]:
     return upstream_shape_functions.unary(a)
 
+def aten〇grid_sampler〡shape(input: List[int], grid: List[int], interpolation_mode: int, padding_mode: int, align_corners: bool) -> List[int]:
+    output = [input[0], input[1], grid[1], grid[2]]
+    return output
+
 def prims〇collapse〡shape(a: List[int], start: int, end: int) -> List[int]:
     # Obtained through trial and error on a few examples in PyTorch:
     assert start < len(a), "start out of bounds"
@@ -2152,6 +2156,10 @@ def aten〇constant_pad_nd〡dtype(self_rank_dtype: Tuple[int, int], pad: List[i
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+def aten〇grid_sampler〡dtype(input_rank_dtype: Tuple[int, int], grid_rank_dtype: Tuple[int, int], interpolation_mode: int, padding_mode: int, align_corners: bool) -> int:
+    input_rank, input_dtype = input_rank_dtype
+    grid_rank, grid_dtype = grid_rank_dtype
+    return input_dtype
 
 @check_dtype_function([ErrorInvocation(TensorOfShape(2, 3, 4), padding=1),
                        ErrorInvocation(TensorOfShape(2, 3, 4), padding=[]),
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 64f03add759e..51c196421b78 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -714,6 +714,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::as_strided_scatter : (Tensor, Tensor, int[], int[], int?) -> (Tensor)")
     emit("aten::upsample_nearest2d : (Tensor, int[], float?, float?) -> (Tensor)")
     emit("aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?) -> (Tensor)")
+    emit("aten::grid_sampler : (Tensor, Tensor, int, int, bool) -> (Tensor)")
 
     # Dict ops.
emit("aten::__contains__.str : (Dict(str, t), str) -> (bool)", has_folder=True) diff --git a/test/Conversion/TorchToLinalg/gridsampler.mlir b/test/Conversion/TorchToLinalg/gridsampler.mlir new file mode 100644 index 000000000000..d392860fa2c1 --- /dev/null +++ b/test/Conversion/TorchToLinalg/gridsampler.mlir @@ -0,0 +1,60 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK: #map +// CHECK-LABEL: func @grid_sampler +// CHECK-DAG: %[[TC0:.*]] = torch_c.to_builtin_tensor %[[ARG0:.*]] : !torch.vtensor<[4,10,10,4],f32> -> tensor<4x10x10x4xf32> +// CHECK-DAG: %[[TC1:.*]] = torch_c.to_builtin_tensor %[[ARG1:.*]] : !torch.vtensor<[4,6,8,2],f32> -> tensor<4x6x8x2xf32> +// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[C2_3:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[DIM:.*]] = tensor.dim %[[TC0]], %[[C2_3]] : tensor<4x10x10x4xf32> +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[DIM_4:.*]] = tensor.dim %[[TC0]], %[[C3]] : tensor<4x10x10x4xf32> +// CHECK-DAG: %[[X2:.*]] = arith.subi %[[DIM:.*]], %[[C1]] : index +// CHECK-DAG: %[[X3:.*]] = arith.subi %[[DIM_4]], %[[C1:.*]] : index +// CHECK-DAG: %[[X4:.*]] = arith.index_cast %[[X2]] : index to i64 +// CHECK-DAG: %[[X5:.*]] = arith.index_cast %[[X3]] : index to i64 +// CHECK-DAG: %[[X6:.*]] = arith.sitofp %[[X4]] : i64 to f32 +// CHECK-DAG: %[[X7:.*]] = arith.sitofp %[[X5]] : i64 to f32 +// CHECK-DAG: %[[X8:.*]] = arith.divf %[[X6]], %[[CST2]] : f32 +// CHECK-DAG: %[[X9:.*]] = arith.divf %[[X7]], %[[CST2]] : f32 +func.func @grid_sampler(%arg0: !torch.vtensor<[4,10,10,4],f32>, %arg1: !torch.vtensor<[4,6,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> { + %true = torch.constant.bool 0 + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 0 + %4 = torch.aten.grid_sampler %arg0, %arg1, %int0, %int1, %true : !torch.vtensor<[4,10,10,4],f32>, !torch.vtensor<[4,6,8,2],f32>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + return %4 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + +// CHECK-LABEL: func @grid_sampler2 +// CHECK: #map +// CHECK-DAG: %[[X15:.*]] = arith.mulf %[[X13:.*]], %[[X8:.*]] : f32 +// CHECK-DAG: %[[X16:.*]] = arith.mulf %[[X14:.*]], %[[X9:.*]] : f32 +// CHECK-DAG: %[[X40:.*]] = arith.mulf %[[EXTRACTED:.*]], %[[X39:.*]] : f32 +// CHECK-DAG: %[[X41:.*]] = arith.mulf %[[X31:.*]], %[[X37:.*]] : f32 +// CHECK-DAG: %[[X42:.*]] = arith.addf %[[X40:.*]], %[[X41]] : f32 +// CHECK-DAG: %[[X43:.*]] = arith.subf %[[CST_1:.*]], %[[X37]] : f32 +// CHECK-DAG: %[[X45:.*]] = arith.mulf %[[X34:.*]], %[[X37]] : f32 +// CHECK-DAG: %[[X46:.*]] = arith.addf %[[X44:.*]], %[[X45]] : f32 +// CHECK-DAG: %[[X47:.*]] = arith.subf %[[CST_1]], %[[X38:.*]] : f32 +// CHECK-DAG: %[[X48:.*]] = arith.mulf %[[X42]], %[[XX47:.*]] : f32 +// CHECK-DAG: %[[X49:.*]] = arith.mulf %[[X46]], %[[XX38:.*]] : f32 +// CHECK-DAG: %[[X50:.*]] = arith.addf %[[X48]], %[[X49]] : f32 +// CHECK-DAG: linalg.yield %[[X50]] : f32 +// CHECK: } -> tensor +// CHECK: %[[X12:.*]] = torch_c.from_builtin_tensor %[[X11:.*]] : tensor -> !torch.vtensor<[?,?,?,?],f32> +// CHECK: return %[[X12]] : 
!torch.vtensor<[?,?,?,?],f32> +func.func @grid_sampler2(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { + %true = torch.constant.bool 0 + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 0 + %4 = torch.aten.grid_sampler %arg0, %arg1, %int0, %int1, %true : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + return %4 : !torch.vtensor<[?,?,?,?],f32> +} \ No newline at end of file diff --git a/test/Conversion/TorchToLinalg/pooling.mlir b/test/Conversion/TorchToLinalg/pooling.mlir index 8ed75f648f5e..8a359ed5627d 100644 --- a/test/Conversion/TorchToLinalg/pooling.mlir +++ b/test/Conversion/TorchToLinalg/pooling.mlir @@ -71,6 +71,5 @@ func.func @forward_max_pool3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch. // CHECK-NEXT: %[[MAXF:.*]] = arith.maximumf %[[CURRENT_VALUE:.*]], %[[ACC_OUT:.*]] : f32 // CHECK-NEXT: linalg.yield %[[MAXF:.*]] : f32 // CHECK: } -> tensor - return %4 : !torch.vtensor<[?,?,?,?,?],f32> }
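
--
Note on the interpolation performed by the linalg.generic body above: each
grid value g in [-1, 1] is unnormalized to (g + 1) * (size - 1) / 2, the
lower neighbor is obtained by fptosi truncation, the upper neighbor is
lower + 1, and upper neighbors that fall past the edge have their
contribution zeroed by the cmpi/select pairs before the bilinear blend.
A minimal NumPy sketch of that computation follows; the helper name and
NumPy framing are illustrative only and not part of the patch, and grid
channel 0 indexes input dim 2 while channel 1 indexes dim 3, as in the
lowering:

import numpy as np

def grid_sampler_ref(inp, grid):
    # inp: (N, C, H, W) float32; grid: (N, Hout, Wout, 2), values in [-1, 1].
    N, C, H, W = inp.shape
    _, Hout, Wout, _ = grid.shape
    out = np.zeros((N, C, Hout, Wout), dtype=inp.dtype)
    for n in range(N):
        for c in range(C):
            for ho in range(Hout):
                for wo in range(Wout):
                    # Unnormalize: (g + 1) * (dim - 1) / 2.
                    x = (grid[n, ho, wo, 0] + 1.0) * (H - 1) / 2.0
                    y = (grid[n, ho, wo, 1] + 1.0) * (W - 1) / 2.0
                    x0, y0 = int(x), int(y)      # arith.fptosi truncation
                    x1, y1 = x0 + 1, y0 + 1
                    dx, dy = x - x0, y - y0
                    # Clamp out-of-range upper neighbors for indexing and
                    # zero their contribution (the cmpi/select pairs).
                    vx, vy = x1 <= H - 1, y1 <= W - 1
                    x1c = x1 if vx else x0
                    y1c = y1 if vy else y0
                    v00 = inp[n, c, x0, y0]
                    v01 = inp[n, c, x0, y1c] if vy else 0.0
                    v10 = inp[n, c, x1c, y0] if vx else 0.0
                    v11 = inp[n, c, x1c, y1c] if (vx and vy) else 0.0
                    # Blend along dim 3 first, then along dim 2.
                    out[n, c, ho, wo] = (
                        (v00 * (1.0 - dy) + v01 * dy) * (1.0 - dx)
                        + (v10 * (1.0 - dy) + v11 * dy) * dx)
    return out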
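
The op itself can also be exercised directly from stock PyTorch, which is
handy for eyeballing the shapes and values the new shape/dtype functions
and lowering are expected to produce. This snippet is an illustration, not
a test from the patch; torch.nn.functional.grid_sample dispatches to
aten::grid_sampler, with interpolation_mode=0 selecting bilinear and
padding_mode=0 selecting zeros:

import torch
import torch.nn.functional as F

inp = torch.rand(4, 10, 10, 4)              # (N, C, H, W)
grid = torch.rand(4, 6, 8, 2) * 2.0 - 1.0   # (N, Hout, Wout, 2) in [-1, 1]

out = torch.ops.aten.grid_sampler(inp, grid, 0, 0, True)
ref = F.grid_sample(inp, grid, mode="bilinear", padding_mode="zeros",
                    align_corners=True)
assert torch.allclose(out, ref)
print(out.shape)  # torch.Size([4, 10, 6, 8]), matching the new shape function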