Skip to content

Commit

Permalink
[torch] GridSample TorchToLinalg lowering (llvm#2883)
Browse files Browse the repository at this point in the history
Lowers `torch.grid_sample` to the equilvalent `linalg` representation.
  • Loading branch information
afalkenberg1 authored Feb 23, 2024
1 parent 5af2495 commit 55dc8de
Show file tree
Hide file tree
Showing 7 changed files with 279 additions and 1 deletion.
27 changes: 27 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -12353,6 +12353,33 @@ def Torch_AtenScaledDotProductAttentionOp : Torch_Op<"aten.scaled_dot_product_at
}];
}

def Torch_AtenGridSamplerOp : Torch_Op<"aten.grid_sampler", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::grid_sampler : (Tensor, Tensor, int, int, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchTensorType:$grid,
Torch_IntType:$interpolation_mode,
Torch_IntType:$padding_mode,
Torch_BoolType:$align_corners
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenGridSamplerOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenGridSamplerOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}

def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
168 changes: 168 additions & 0 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2360,6 +2360,172 @@ class ConvertCastEquivalentOp : public OpConversionPattern<OpTy> {
};
} // namespace

namespace {
class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenGridSamplerOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Type int64type = rewriter.getI64Type();
Type floatType = rewriter.getF32Type();
Value zeroIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value oneIndex = rewriter.create<arith::ConstantIndexOp>(loc, 1);
Value twoIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2);
Value zeroFloat = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(floatType, 0.0));
Value oneFloat = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(floatType, 1.0));
Value twoFloat = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(floatType, 2.0));
Value input = adaptor.getInput();
auto inputType = input.getType().cast<RankedTensorType>();
auto inputShape = inputType.getShape();
Value innerDim0a = rewriter.create<tensor::DimOp>(loc, input, 2);
Value innerDim1a = rewriter.create<tensor::DimOp>(loc, input, 3);
Value innerDim0b =
rewriter.create<arith::SubIOp>(loc, innerDim0a, oneIndex);
Value innerDim1b =
rewriter.create<arith::SubIOp>(loc, innerDim1a, oneIndex);
Value innerDim0c =
rewriter.create<arith::IndexCastOp>(loc, int64type, innerDim0b);
Value innerDim1c =
rewriter.create<arith::IndexCastOp>(loc, int64type, innerDim1b);
Value innerDim0d =
rewriter.create<arith::SIToFPOp>(loc, floatType, innerDim0c);
Value innerDim1d =
rewriter.create<arith::SIToFPOp>(loc, floatType, innerDim1c);
Value innerDim0e =
rewriter.create<arith::DivFOp>(loc, innerDim0d, twoFloat);
Value innerDim1e =
rewriter.create<arith::DivFOp>(loc, innerDim1d, twoFloat);
Value grid = adaptor.getGrid();
auto gridType = grid.getType().cast<RankedTensorType>();
auto gridShape = gridType.getShape();
auto gridRank = gridType.getRank();
SmallVector<Value> extractGridOffsets0(gridRank, zeroIndex);
SmallVector<Value> extractGridShape = getTensorSizes(rewriter, loc, grid);
SmallVector<Value> extractGridStride(gridRank, oneIndex);
int64_t lastGridDim = gridRank - 1;
extractGridShape[lastGridDim] = oneIndex;
extractGridStride[lastGridDim] = twoIndex;
SmallVector<Value> extractGridOffsets1(gridRank, zeroIndex);
extractGridOffsets1[lastGridDim] = oneIndex;
SmallVector<int64_t> gridShapeExtracted(gridShape);
gridShapeExtracted.back() = 1;
SmallVector<int64_t> gridShapeCollapsed{gridShape[0], gridShape[1],
gridShape[2]};
auto grid0 = rewriter.create<tensor::ExtractSliceOp>(
loc, grid, extractGridOffsets0, extractGridShape, extractGridStride);
auto grid1 = rewriter.create<tensor::ExtractSliceOp>(
loc, grid, extractGridOffsets1, extractGridShape, extractGridStride);
SmallVector<ReassociationIndices> associations{ReassociationIndices{0},
ReassociationIndices{1},
ReassociationIndices{2, 3}};
auto gridCollapsed0 =
rewriter.create<tensor::CollapseShapeOp>(loc, grid0, associations);
auto gridCollapsed1 =
rewriter.create<tensor::CollapseShapeOp>(loc, grid1, associations);
AffineMap gridMap = AffineMap::get(4, 0,
{rewriter.getAffineDimExpr(0),
rewriter.getAffineDimExpr(2),
rewriter.getAffineDimExpr(3)},
op->getContext());
SmallVector<AffineMap> gridMaps{gridMap, gridMap,
rewriter.getMultiDimIdentityMap(gridRank)};
SmallVector<utils::IteratorType> gridIterators(
gridRank, utils::IteratorType::parallel);
SmallVector<int64_t> resultShape{inputShape[0], inputShape[1], gridShape[1],
gridShape[2]};
auto lambdaExtract = [](OpBuilder &b, Location loc, Value input, Value idxA,
Value idxB, Value idxC, Value idxD) -> Value {
SmallVector<Value> index{idxA, idxB, idxC, idxD};
Value result = b.create<tensor::ExtractOp>(loc, input, index);
return result;
};
auto lambdaInter = [&](OpBuilder &b, Location loc, Value x, Value y,
Value d) -> Value {
Value dm = b.create<arith::SubFOp>(loc, oneFloat, d);
Value ra = b.create<arith::MulFOp>(loc, x, dm);
Value rb = b.create<arith::MulFOp>(loc, y, d);
Value res = b.create<arith::AddFOp>(loc, ra, rb);
return res;
};
auto resultType = getTypeConverter()
->convertType(op.getResult().getType())
.cast<RankedTensorType>();
llvm::SmallVector<Value> resultSize{
rewriter.create<tensor::DimOp>(loc, input, 0),
rewriter.create<tensor::DimOp>(loc, input, 1),
rewriter.create<tensor::DimOp>(loc, grid, 1),
rewriter.create<tensor::DimOp>(loc, grid, 2)};
Value resultFinal =
rewriter.create<tensor::EmptyOp>(loc, resultType, resultSize);
auto sGrid = rewriter.create<linalg::GenericOp>(
loc, TypeRange{resultType}, ValueRange{gridCollapsed0, gridCollapsed1},
ValueRange(resultFinal), gridMaps, gridIterators,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value gr0 = args[0];
Value gr1 = args[1];
Value gplus0 = b.create<arith::AddFOp>(loc, gr0, oneFloat);
Value gplus1 = b.create<arith::AddFOp>(loc, gr1, oneFloat);
Value result0 = b.create<arith::MulFOp>(loc, gplus0, innerDim0e);
Value result1 = b.create<arith::MulFOp>(loc, gplus1, innerDim1e);
Value lower0 = b.create<arith::FPToSIOp>(loc, int64type, result0);
Value lower1 = b.create<arith::FPToSIOp>(loc, int64type, result1);
Value oneInt =
b.create<arith::ConstantOp>(loc, b.getIntegerAttr(int64type, 1));
Value upper0 =
b.create<arith::AddIOp>(loc, int64type, lower0, oneInt);
Value upper1 =
b.create<arith::AddIOp>(loc, int64type, lower1, oneInt);
Value notValid0 = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sgt, upper0, innerDim0c);
Value notValid1 = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sgt, upper1, innerDim1c);
Value upperValid0 =
b.create<arith::SelectOp>(loc, notValid0, lower0, upper0);
Value upperValid1 =
b.create<arith::SelectOp>(loc, notValid1, lower1, upper1);
Value lw0 =
b.create<arith::IndexCastOp>(loc, b.getIndexType(), lower0);
Value lw1 =
b.create<arith::IndexCastOp>(loc, b.getIndexType(), lower1);
Value up0 =
b.create<arith::IndexCastOp>(loc, b.getIndexType(), upperValid0);
Value up1 =
b.create<arith::IndexCastOp>(loc, b.getIndexType(), upperValid1);
Value N = b.create<linalg::IndexOp>(loc, 0);
Value C = b.create<linalg::IndexOp>(loc, 1);
Value result00 = lambdaExtract(b, loc, input, N, C, lw0, lw1);
Value result01 = lambdaExtract(b, loc, input, N, C, lw0, up1);
Value result01a =
b.create<arith::SelectOp>(loc, notValid1, zeroFloat, result01);
Value result10 = lambdaExtract(b, loc, input, N, C, up0, lw1);
Value result10a =
b.create<arith::SelectOp>(loc, notValid0, zeroFloat, result10);
Value result11 = lambdaExtract(b, loc, input, N, C, up0, up1);
Value result11a =
b.create<arith::SelectOp>(loc, notValid0, zeroFloat, result11);
Value result11b =
b.create<arith::SelectOp>(loc, notValid1, zeroFloat, result11a);
Value lw0a = b.create<arith::SIToFPOp>(loc, floatType, lower0);
Value lw1a = b.create<arith::SIToFPOp>(loc, floatType, lower1);
Value d0 = b.create<arith::SubFOp>(loc, result0, lw0a);
Value d1 = b.create<arith::SubFOp>(loc, result1, lw1a);
Value resultScaled0 = lambdaInter(b, loc, result00, result01a, d0);
Value resultScaled1 = lambdaInter(b, loc, result10a, result11b, d0);
Value resultScaled =
lambdaInter(b, loc, resultScaled0, resultScaled1, d1);
b.create<linalg::YieldOp>(loc, resultScaled);
});
rewriter.replaceOp(op, sGrid.getResults());
return success();
}
};
} // namespace

void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
Expand Down Expand Up @@ -2412,4 +2578,6 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
typeConverter, context);
target.addIllegalOp<Aten_MakePerTensorQuantizedTensorOp>();
patterns.add<ConvertDequantizePerChannel>(typeConverter, context);
target.addIllegalOp<AtenGridSamplerOp>();
patterns.add<ConvertAtenGridSamplerOp>(typeConverter, context);
}
15 changes: 15 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6597,6 +6597,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.grid_sampler\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.bool) -> !torch.list<int> {\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %int2 = torch.constant.int 2\n"
" %0 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %2 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %3 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %4 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.prims.collapse\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list<int> {\n"
" %true = torch.constant.bool true\n"
" %str = torch.constant.str \"AssertionError: start must be less than or equal to end\"\n"
Expand Down Expand Up @@ -9795,6 +9806,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.grid_sampler\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.bool) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.reflection_pad1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: padding size expected to be 2\"\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,10 @@ def aten〇_make_per_tensor_quantized_tensor〡shape(self: List[int], scale: flo
def prims〇convert_element_type〡shape(a: List[int], dtype: int) -> List[int]:
return upstream_shape_functions.unary(a)

def aten〇grid_sampler〡shape(input: List[int], grid: List[int], interpolation_mode: int, padding_mode: int, align_corners: bool) -> List[int]:
output = [input[0],input[1],grid[1],grid[2]]
return output

def prims〇collapse〡shape(a: List[int], start: int, end: int) -> List[int]:
# Obtained through trial and error on a few examples in PyTorch:
assert start < len(a), "start out of bounds"
Expand Down Expand Up @@ -2152,6 +2156,10 @@ def aten〇constant_pad_nd〡dtype(self_rank_dtype: Tuple[int, int], pad: List[i
self_rank, self_dtype = self_rank_dtype
return self_dtype

def aten〇grid_sampler〡dtype(input_rank_dtype: Tuple[int, int], grid_rank_dtype: Tuple[int, int], interpolation_mode: int, padding_mode: int, align_corners: bool) -> int:
input_rank, input_dtype = input_rank_dtype
grid_rank, grid_dtype = input_rank_dtype
return input_dtype

@check_dtype_function([ErrorInvocation(TensorOfShape(2, 3, 4), padding=1),
ErrorInvocation(TensorOfShape(2, 3, 4), padding=[]),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::as_strided_scatter : (Tensor, Tensor, int[], int[], int?) -> (Tensor)")
emit("aten::upsample_nearest2d : (Tensor, int[], float?, float?) -> (Tensor)")
emit("aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?) -> (Tensor)")
emit("aten::grid_sampler : (Tensor, Tensor, int, int, bool) -> (Tensor)")

# Dict ops.
emit("aten::__contains__.str : (Dict(str, t), str) -> (bool)", has_folder=True)
Expand Down
60 changes: 60 additions & 0 deletions test/Conversion/TorchToLinalg/gridsampler.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s

// CHECK: #map
// CHECK-LABEL: func @grid_sampler
// CHECK-DAG: %[[TC0:.*]] = torch_c.to_builtin_tensor %[[ARG0:.*]] : !torch.vtensor<[4,10,10,4],f32> -> tensor<4x10x10x4xf32>
// CHECK-DAG: %[[TC1:.*]] = torch_c.to_builtin_tensor %[[ARG1:.*]] : !torch.vtensor<[4,6,8,2],f32> -> tensor<4x6x8x2xf32>
// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
// CHECK-DAG: %[[C2_3:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[DIM:.*]] = tensor.dim %[[TC0]], %[[C2_3]] : tensor<4x10x10x4xf32>
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[DIM_4:.*]] = tensor.dim %[[TC0]], %[[C3]] : tensor<4x10x10x4xf32>
// CHECK-DAG: %[[X2:.*]] = arith.subi %[[DIM:.*]], %[[C1]] : index
// CHECK-DAG: %[[X3:.*]] = arith.subi %[[DIM_4]], %[[C1:.*]] : index
// CHECK-DAG: %[[X4:.*]] = arith.index_cast %[[X2]] : index to i64
// CHECK-DAG: %[[X5:.*]] = arith.index_cast %[[X3]] : index to i64
// CHECK-DAG: %[[X6:.*]] = arith.sitofp %[[X4]] : i64 to f32
// CHECK-DAG: %[[X7:.*]] = arith.sitofp %[[X5]] : i64 to f32
// CHECK-DAG: %[[X8:.*]] = arith.divf %[[X6]], %[[CST2]] : f32
// CHECK-DAG: %[[X9:.*]] = arith.divf %[[X7]], %[[CST2]] : f32
func.func @grid_sampler(%arg0: !torch.vtensor<[4,10,10,4],f32>, %arg1: !torch.vtensor<[4,6,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
%true = torch.constant.bool 0
%int0 = torch.constant.int 0
%int1 = torch.constant.int 0
%4 = torch.aten.grid_sampler %arg0, %arg1, %int0, %int1, %true : !torch.vtensor<[4,10,10,4],f32>, !torch.vtensor<[4,6,8,2],f32>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
return %4 : !torch.vtensor<[?,?,?,?],f32>
}

// -----

// CHECK-LABEL: func @grid_sampler2
// CHECK: #map
// CHECK-DAG: %[[X15:.*]] = arith.mulf %[[X13:.*]], %[[X8:.*]] : f32
// CHECK-DAG: %[[X16:.*]] = arith.mulf %[[X14:.*]], %[[X9:.*]] : f32
// CHECK-DAG: %[[X40:.*]] = arith.mulf %[[EXTRACTED:.*]], %[[X39:.*]] : f32
// CHECK-DAG: %[[X41:.*]] = arith.mulf %[[X31:.*]], %[[X37:.*]] : f32
// CHECK-DAG: %[[X42:.*]] = arith.addf %[[X40:.*]], %[[X41]] : f32
// CHECK-DAG: %[[X43:.*]] = arith.subf %[[CST_1:.*]], %[[X37]] : f32
// CHECK-DAG: %[[X45:.*]] = arith.mulf %[[X34:.*]], %[[X37]] : f32
// CHECK-DAG: %[[X46:.*]] = arith.addf %[[X44:.*]], %[[X45]] : f32
// CHECK-DAG: %[[X47:.*]] = arith.subf %[[CST_1]], %[[X38:.*]] : f32
// CHECK-DAG: %[[X48:.*]] = arith.mulf %[[X42]], %[[XX47:.*]] : f32
// CHECK-DAG: %[[X49:.*]] = arith.mulf %[[X46]], %[[XX38:.*]] : f32
// CHECK-DAG: %[[X50:.*]] = arith.addf %[[X48]], %[[X49]] : f32
// CHECK-DAG: linalg.yield %[[X50]] : f32
// CHECK: } -> tensor<?x?x?x?xf32>
// CHECK: %[[X12:.*]] = torch_c.from_builtin_tensor %[[X11:.*]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[X12]] : !torch.vtensor<[?,?,?,?],f32>
func.func @grid_sampler2(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
%true = torch.constant.bool 0
%int0 = torch.constant.int 0
%int1 = torch.constant.int 0
%4 = torch.aten.grid_sampler %arg0, %arg1, %int0, %int1, %true : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
return %4 : !torch.vtensor<[?,?,?,?],f32>
}
1 change: 0 additions & 1 deletion test/Conversion/TorchToLinalg/pooling.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,5 @@ func.func @forward_max_pool3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.
// CHECK-NEXT: %[[MAXF:.*]] = arith.maximumf %[[CURRENT_VALUE:.*]], %[[ACC_OUT:.*]] : f32
// CHECK-NEXT: linalg.yield %[[MAXF:.*]] : f32
// CHECK: } -> tensor<?x?x?x?x?xf32>

return %4 : !torch.vtensor<[?,?,?,?,?],f32>
}

0 comments on commit 55dc8de

Please sign in to comment.