AtenAdaptiveMaxPool2d Conversion to Linalg (#2779)
The logic here is very similar to the conversion for AdaptiveAvgPool1d
#2661 with a few modifications:

1. buffVal = -inf instead of 0
2. the main linalg generic op accumulates a max, instead of a sum, into
the first output tensor
3. avg pooling requires dividing the sum pool by the kernel width, which
we stored as an auxiliary tensor (kSizeTensor). Here, the auxiliary
tensor records the indices instead. Strangely enough, the only signature
available for this op returns the indices, so it appears they must be
computed whether the user wants them or not. See
[pytorch/torch/nn/functional.py](https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py#L1174).

Before writing other adaptive pooling conversions, the logic of this
decomposition should be rolled into a helper function that works for
both max and avg pooling ops. Even the auxiliary tensor should likely be
automated. This code was written in a slightly more tedious way than
strictly necessary (often using loops to fill SmallVectors up to rank - 2,
which is only two in this case) to make the transition to a helper
function easier.
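
For context, here is a minimal Python sketch (not part of this commit; the
function name is illustrative) of the semantics the decomposition implements
for NCHW input. The window bounds match the `st`/`en` formulas used in the
conversion, and each returned index is flattened as `ih * w_in + iw`:

```python
import math

def adaptive_max_pool2d_ref(x, h_out, w_out):
    # x is a nested list in NCHW layout; returns (max values, flat indices)
    n, c, h_in, w_in = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    vals = [[[[0.0] * w_out for _ in range(h_out)] for _ in range(c)] for _ in range(n)]
    idxs = [[[[0] * w_out for _ in range(h_out)] for _ in range(c)] for _ in range(n)]
    for b in range(n):
        for ch in range(c):
            for oh in range(h_out):
                # window bounds along H: start = (oh * h_in) // h_out,
                # end = 1 + ((oh + 1) * h_in - 1) // h_out
                hs = (oh * h_in) // h_out
                he = 1 + ((oh + 1) * h_in - 1) // h_out
                for ow in range(w_out):
                    ws = (ow * w_in) // w_out
                    we = 1 + ((ow + 1) * w_in - 1) // w_out
                    best, best_idx = -math.inf, 0
                    for ih in range(hs, he):
                        for iw in range(ws, we):
                            v = x[b][ch][ih][iw]
                            if v > best:
                                # index flattened as ih * w_in + iw, matching
                                # the auxiliary (indices) tensor in the conversion
                                best, best_idx = v, ih * w_in + iw
                    vals[b][ch][oh][ow] = best
                    idxs[b][ch][oh][ow] = best_idx
    return vals, idxs
```

For example, calling `adaptive_max_pool2d_ref(x, 2, 2)` on a 1x1x4x4 input
takes the max (and its flat index) over each 2x2 quadrant.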
zjgarvey authored Jan 24, 2024
1 parent 311b6b0 commit c531f54
Showing 6 changed files with 407 additions and 0 deletions.
25 changes: 25 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6856,6 +6856,31 @@ def Torch_Aten_AdaptiveAvgPool3dBackwardOp : Torch_Op<"aten._adaptive_avg_pool3d
}];
}

def Torch_AtenAdaptiveMaxPool2dOp : Torch_Op<"aten.adaptive_max_pool2d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::adaptive_max_pool2d : (Tensor, int[]) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$output_size
);
let results = (outs
AnyTorchTensorType:$result0,
AnyTorchTensorType:$result1
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAdaptiveMaxPool2dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 2);
}
void AtenAdaptiveMaxPool2dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 2);
}
}];
}

def Torch_AtenTopkOp : Torch_Op<"aten.topk", [
AllowsTypeRefinement,
HasValueSemantics,
204 changes: 204 additions & 0 deletions lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -793,6 +793,208 @@ class ConvertAtenAdaptiveAvgPool1dOp
};
} // namespace

// The logic for this conversion is similar to the AdaptiveAvgPool1dOp
// conversion. Before writing any more adaptive pooling conversions, the logic
// here should be off-loaded to a helper function, since each of the adaptive
// ops is essentially the same with some minor tweaks. Instead of kSizeTensor,
// we named the additional output of the linalg generic op auxTensor.
// For max pooling, auxTensor holds the indices of the max values; for
// avg pooling, auxTensor would be kSizeTensor, used later to divide the
// sum pool by the kernel size.
namespace {
class ConvertAtenAdaptiveMaxPool2dOp
: public OpConversionPattern<AtenAdaptiveMaxPool2dOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenAdaptiveMaxPool2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

Location loc = op->getLoc();
const TypeConverter *typeConverter = getTypeConverter();

// get rank of input (same as rank of output)
int64_t rank =
adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
// input operand should be NCHW (i.e. rank 4)
if (rank != 4) {
return rewriter.notifyMatchFailure(op, "only supports input type NCHW");
}

// input tensor and output shape
Value input = adaptor.getSelf();
Value outputShape = op.getOutputSize();
SmallVector<Value> outShapeVector;
getListConstructElements(outputShape, outShapeVector);
outShapeVector =
getTypeConvertedValues(rewriter, loc, typeConverter, outShapeVector);
SmallVector<Value> inputSpatialSizes;
for (unsigned i = 2; i < rank; i++) {
inputSpatialSizes.push_back(getDimOp(rewriter, loc, input, i));
}
SmallVector<Value> outShapeIndexVector;
for (auto v : outShapeVector) {
outShapeIndexVector.push_back(castIntToIndex(rewriter, loc, v));
}
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
RankedTensorType outputType =
typeConverter->convertType(op.getResult0().getType())
.cast<RankedTensorType>();

// get elementType of input tensor
Type elementType = inputType.getElementType();

// make an iteration space of size kMax = 1 + ceildiv(hIn - 1, hOut)
Type boolType = rewriter.getI1Type();
SmallVector<Value> kIterSizeVector;
Value constantOne =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
for (int i = 0; i < rank - 2; i++) {
Value hInPlusOne = rewriter.create<arith::SubIOp>(
loc, inputSpatialSizes[i], constantOne);
Value kMaxMinusOne = rewriter.create<arith::CeilDivSIOp>(
loc, hInPlusOne, outShapeIndexVector[i]);
Value kMax =
rewriter.create<arith::AddIOp>(loc, constantOne, kMaxMinusOne);
kIterSizeVector.push_back(kMax);
}
Value kIter = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(kIterSizeVector), boolType);

// need to buffer input, else there will possibly be an out of bounds access
// later. buffVal = 0 for avg pooling and -inf for max pooling
auto smallestFPValueAttr = rewriter.getFloatAttr(
elementType,
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
/*Negative=*/true));
Value buffVal = rewriter.create<arith::ConstantOp>(loc, elementType,
smallestFPValueAttr);
SmallVector<int64_t> lowPadding(rank, 0);
SmallVector<int64_t> highPadding(2, 0);
for (int i = 0; i < rank - 2; i++) {
highPadding.push_back(1);
}
Value buffInput = torch_to_linalg::getPaddedTensor(
op, rewriter, input, lowPadding, highPadding, buffVal);

// make a list of outputSizes
SmallVector<Value> outputSizes;
for (unsigned i = 0; i < 2; i++) {
outputSizes.push_back(getDimOp(rewriter, loc, input, i));
}
for (unsigned i = 2; i < rank; i++) {
outputSizes.push_back(outShapeIndexVector[i - 2]);
}

// for avg pooling the auxTensor should hold kernel widths (kSizeTensor)
// for max Pooling, it should hold the indices
RankedTensorType outputType1 =
typeConverter->convertType(op.getResult1().getType())
.cast<RankedTensorType>();
Type indicesType = outputType1.getElementType();
Value auxTensor = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(outputSizes), indicesType);

// initialize an output tensor
Value initOutput =
createInitTensor(rewriter, loc, outputSizes, elementType, buffVal);

// set up indexing maps and iterator types for the linalg generic op; the
// iteration space has rank output dims plus rank - 2 kIter dims:
//   kIter:     (d0,d1,d2,d3,d4,d5) -> (d4,d5)
//   output:    (d0,d1,d2,d3,d4,d5) -> (d0,d1,d2,d3)
//   auxTensor: (d0,d1,d2,d3,d4,d5) -> (d0,d1,d2,d3) (or (d2,d3) for avg pooling)
SmallVector<AffineExpr> kIterExprs, outputExprs, auxTensorExprs;
// batch + channel + output spatial dims
for (unsigned i = 0; i < rank; i++) {
outputExprs.push_back(rewriter.getAffineDimExpr(i));
auxTensorExprs.push_back(rewriter.getAffineDimExpr(i));
}
// kIter covers last rank-2 indices
for (unsigned i = rank; i < 2 * rank - 2; i++) {
kIterExprs.push_back(rewriter.getAffineDimExpr(i));
}
SmallVector<AffineMap> indexingMaps =
AffineMap::inferFromExprList({kIterExprs, outputExprs, auxTensorExprs});
SmallVector<utils::IteratorType> iteratorTypes(
rank, utils::IteratorType::parallel);
for (unsigned i = 0; i < rank - 2; i++) {
iteratorTypes.push_back(utils::IteratorType::reduction);
}
Value indexOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto maxPool = rewriter.create<linalg::GenericOp>(
loc, /*resultTensorTypes=*/
TypeRange({initOutput.getType(), auxTensor.getType()}),
/*inputs=*/ValueRange({kIter}),
/*outputs=*/ValueRange({initOutput, auxTensor}),
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value res = args[1];
Value maxIndex = args[2];
SmallVector<Value> ind;
for (unsigned i = 0; i < 2 * rank - 2; i++) {
ind.push_back(b.create<linalg::IndexOp>(loc, i));
}
// compute start and end indices
// st = s1( s0(ind[i] * hIn) // hOut )
SmallVector<Value> starts;
SmallVector<Value> ends;
for (unsigned i = 2; i < rank; i++) {
Value s0 =
b.create<arith::MulIOp>(loc, ind[i], inputSpatialSizes[i - 2]);
Value s1 = b.create<arith::FloorDivSIOp>(
loc, s0, outShapeIndexVector[i - 2]);
starts.push_back(s1);
// en = e4( 1 + e3( e2( e1( e0(ind[i] + 1) * hIn ) - 1 ) // hOut ) )
Value e0 = b.create<arith::AddIOp>(loc, ind[i], indexOne);
Value e1 =
b.create<arith::MulIOp>(loc, e0, inputSpatialSizes[i - 2]);
Value e2 = b.create<arith::SubIOp>(loc, e1, indexOne);
Value e3 = b.create<arith::FloorDivSIOp>(
loc, e2, outShapeIndexVector[i - 2]);
Value e4 = b.create<arith::AddIOp>(loc, indexOne, e3);
ends.push_back(e4);
}
SmallVector<Value> inputElementIndices;
inputElementIndices.push_back(ind[0]);
inputElementIndices.push_back(ind[1]);
for (unsigned i = 2; i < rank; i++) {
inputElementIndices.push_back(
b.create<arith::AddIOp>(loc, starts[i - 2], ind[rank - 2 + i]));
}
Value inElt = b.create<tensor::ExtractOp>(loc, elementType, buffInput,
inputElementIndices);
// check if we extracted at window index < end index
for (unsigned i = 0; i < rank - 2; i++) {
Value cond =
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate(6),
inputElementIndices[i + 2], ends[i]);
inElt = b.create<arith::SelectOp>(loc, cond, inElt, buffVal);
}
Value cond1 = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
inElt, res);
// index location is (ih * input_width + iw)
Value indexOut0 = b.create<arith::MulIOp>(loc, inputElementIndices[2],
inputSpatialSizes[1]);
Value indexOut1 =
b.create<arith::AddIOp>(loc, indexOut0, inputElementIndices[3]);
Value indexOut1Int = castIndexToInt64(b, loc, indexOut1);
Value indexOut2 =
b.create<arith::SelectOp>(loc, cond1, indexOut1Int, maxIndex);
Value out2 = b.create<arith::SelectOp>(loc, cond1, inElt, res);
b.create<linalg::YieldOp>(loc, ValueRange({out2, indexOut2}));
});

Value maxValues = rewriter.create<tensor::CastOp>(
loc, outputType, maxPool.getResultTensors()[0]);
Value outputIndices = rewriter.create<tensor::CastOp>(
loc, outputType1, maxPool.getResultTensors()[1]);
rewriter.replaceOp(op, {maxValues, outputIndices});
return success();
}
};
} // namespace

void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
@@ -813,4 +1015,6 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
typeConverter, context);
target.addIllegalOp<AtenAdaptiveAvgPool1dOp>();
patterns.add<ConvertAtenAdaptiveAvgPool1dOp>(typeConverter, context);
target.addIllegalOp<AtenAdaptiveMaxPool2dOp>();
patterns.add<ConvertAtenAdaptiveMaxPool2dOp>(typeConverter, context);
}
73 changes: 73 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7964,6 +7964,73 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.adaptive_avg_pool2d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.adaptive_max_pool2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.tuple<list<int>, list<int>> {\n"
" %0 = call @__torch__.adaptive_max_pool2d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.tuple<list<int>, list<int>>\n"
" return %0 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @__torch__.adaptive_max_pool2d(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.tuple<list<int>, list<int>> {\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int2 = torch.constant.int 2\n"
" %int3 = torch.constant.int 3\n"
" %int4 = torch.constant.int 4\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %3 = torch.aten.eq.int %2, %int3 : !torch.int, !torch.int -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %11 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %12 = torch.aten.eq.int %11, %int4 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %12 : !torch.bool\n"
" }\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" torch.prim.Loop %5, %true, init() {\n"
" ^bb0(%arg2: !torch.int):\n"
" %11 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %12 = torch.aten.ne.int %11, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %12 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %6 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %7 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %8 = torch.aten.sub.int %7, %int2 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.Loop %8, %true, init() {\n"
" ^bb0(%arg2: !torch.int):\n"
" %11 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %12 = torch.aten.append.t %6, %11 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %9 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" torch.prim.Loop %9, %true, init() {\n"
" ^bb0(%arg2: !torch.int):\n"
" %11 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %12 = torch.aten.append.t %6, %11 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %10 = torch.prim.TupleConstruct %6, %6 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" return %10 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.flatten.using_ints\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.flatten(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -9896,6 +9963,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %1 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.adaptive_max_pool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.tuple<int, int> {\n"
" %int4 = torch.constant.int 4\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %1 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.mish\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
@@ -902,6 +902,24 @@ def aten〇avg_pool2d〡shape(self: List[int], kernel_size: List[int], stride: L
def aten〇adaptive_avg_pool2d〡shape(self: List[int], output_size: List[int]) -> List[int]:
return upstream_shape_functions.adaptive_avg_pool2d(self, output_size)

def adaptive_max_pool2d(self: List[int], out: List[int]):
assert len(out) == 2
assert len(self) == 3 or len(self) == 4

for i in range(len(self)):
assert self[i] != 0

shape: List[int] = []
for i in range(len(self) - 2):
shape.append(self[i])
for j in range(len(out)):
shape.append(out[j])

return shape, shape

def aten〇adaptive_max_pool2d〡shape(self: List[int], output_size: List[int]) -> Tuple[List[int], List[int]]:
return adaptive_max_pool2d(self, output_size)

def aten〇flatten〇using_ints〡shape(self: List[int], start_dim: int = 0, end_dim: int = -1) -> List[int]:
return upstream_shape_functions.flatten(self, start_dim, end_dim)

@@ -2334,6 +2352,11 @@ def aten〇max_pool2d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], ker
self_rank, self_dtype = self_rank_dtype
return self_dtype, torch.int64

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[2, 2]))
def aten〇adaptive_max_pool2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int]) -> Tuple[int, int]:
self_rank, self_dtype = self_rank_dtype
return self_dtype, torch.int64

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇mish〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
@@ -502,6 +502,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)")
emit("aten::_adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)")
emit("aten::_adaptive_avg_pool3d_backward : (Tensor, Tensor) -> (Tensor)")
emit("aten::adaptive_max_pool2d : (Tensor, int[]) -> (Tensor, Tensor)")
emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")
emit("aten::pixel_shuffle : (Tensor, int) -> (Tensor)")
